use std::time::Duration;
use tokio_util::sync::CancellationToken;
use zagens_runtime_adapters::tools::{check_url_policy, is_http_url, is_restricted_ip};
use crate::tools::spec::{ToolContext, ToolError};
/// Maximum number of redirects `fetch_with_ssrf_guard` will follow after the
/// initial request before giving up with a "too many redirects" error.
pub(crate) const MAX_REDIRECTS: usize = 5;
/// Interprets `host` as a literal IP address, if it is one.
///
/// Leading/trailing whitespace is ignored, and the address may be wrapped in
/// square brackets as IPv6 literals are in a URL authority (e.g. `[::1]`).
/// Anything that is not a valid IPv4/IPv6 literal (such as a DNS name, or a
/// half-bracketed string like `[::1`) yields `None`.
fn parse_host_as_ip(host: &str) -> Option<std::net::IpAddr> {
    let host = host.trim();
    let unbracketed = match host.strip_prefix('[').and_then(|inner| inner.strip_suffix(']')) {
        Some(inner) => inner,
        None => host,
    };
    unbracketed.parse::<std::net::IpAddr>().ok()
}
/// Fails fast if the optional cancellation token has already been triggered.
///
/// A `None` token means cancellation is not in use, so the result is always
/// `Ok(())` in that case.
///
/// # Errors
/// Returns an execution-failed `ToolError` with message `"request canceled"`
/// when the token is present and cancelled.
pub(crate) fn ensure_not_cancelled(cancel: Option<&CancellationToken>) -> Result<(), ToolError> {
    match cancel {
        Some(token) if token.is_cancelled() => {
            Err(ToolError::execution_failed("request canceled"))
        }
        _ => Ok(()),
    }
}
/// Streams the body of `resp` into memory, keeping at most `max_bytes` bytes.
///
/// Returns the collected bytes plus a flag that is `true` when the body was
/// cut short because it exceeded `max_bytes`; in that case the returned bytes
/// are exactly the first `max_bytes` of the body. When a cancellation token is
/// supplied it is checked before every chunk and also raced (with priority)
/// against each chunk read, so a stalled stream can still be abandoned.
///
/// # Errors
/// - `"request canceled"` if the token fires;
/// - `"failed to read body: …"` if reading a chunk fails.
pub(crate) async fn read_body_capped(
    mut resp: reqwest::Response,
    max_bytes: usize,
    cancel: Option<&CancellationToken>,
) -> Result<(Vec<u8>, bool), ToolError> {
    let mut body: Vec<u8> = Vec::new();
    loop {
        ensure_not_cancelled(cancel)?;
        let next = match cancel {
            Some(token) => {
                tokio::select! {
                    biased;
                    _ = token.cancelled() => {
                        return Err(ToolError::execution_failed("request canceled"));
                    }
                    next = resp.chunk() => next,
                }
            }
            None => resp.chunk().await,
        };
        let piece = match next {
            Ok(Some(piece)) => piece,
            // End of stream: everything fit within the cap.
            Ok(None) => return Ok((body, false)),
            Err(e) => {
                return Err(ToolError::execution_failed(format!(
                    "failed to read body: {e}"
                )));
            }
        };
        let room = max_bytes.saturating_sub(body.len());
        if piece.len() > room {
            // Keep only what fits and report truncation; the rest is dropped.
            body.extend_from_slice(&piece[..room]);
            return Ok((body, true));
        }
        body.extend_from_slice(&piece);
    }
}
/// Checks `url` against the network policy and SSRF rules for `tool_name`.
///
/// Returns:
/// - `Ok(None)` when `check_url_policy` yields no host, or when the host is a
///   literal IP address that is not restricted (there is nothing to pin);
/// - `Ok(Some((host, ip)))` when the host is a name whose resolved addresses
///   are all unrestricted. `ip` is the first resolved address, intended to be
///   pinned for the real request so a second DNS lookup cannot substitute a
///   different address.
///
/// # Errors
/// - permission denied if the policy rejects the URL, the host is
///   `localhost` / `localhost.localdomain`, or the literal IP or any resolved
///   IP is reported restricted by `is_restricted_ip`;
/// - execution failed if DNS resolution errors or returns no addresses, so an
///   unresolvable host is blocked rather than allowed (fail closed).
pub(crate) async fn validate_url_ssrf(
    context: &ToolContext,
    tool_name: &str,
    url: &str,
) -> Result<Option<(String, std::net::IpAddr)>, ToolError> {
    let Some(host) = check_url_policy(context.network_policy.as_ref(), tool_name, url)
        .map_err(|e| ToolError::permission_denied(e.denial_message()))?
    else {
        return Ok(None);
    };

    if matches!(host.as_str(), "localhost" | "localhost.localdomain") {
        return Err(ToolError::permission_denied(
            "requests to localhost are not allowed",
        ));
    }

    // Literal IPs are judged directly; no lookup, and no pinning needed.
    if let Some(literal) = parse_host_as_ip(&host) {
        return if is_restricted_ip(&literal) {
            Err(ToolError::permission_denied(format!(
                "IP {literal} is a restricted address (private/loopback/link-local)"
            )))
        } else {
            Ok(None)
        };
    }

    let resolved: Vec<std::net::IpAddr> =
        match tokio::net::lookup_host((host.as_str(), 0u16)).await {
            Ok(addrs) => addrs.map(|sock| sock.ip()).collect(),
            Err(e) => {
                return Err(ToolError::execution_failed(format!(
                    "DNS resolution failed for `{host}`: {e}"
                )));
            }
        };

    // Reject the host if *any* answer is internal, not just the first one.
    if let Some(bad) = resolved.iter().copied().find(|ip| is_restricted_ip(ip)) {
        return Err(ToolError::permission_denied(format!(
            "resolved IP {bad} is a restricted address (private/loopback/link-local)"
        )));
    }

    match resolved.first() {
        Some(&pinned) => Ok(Some((host, pinned))),
        None => Err(ToolError::execution_failed(format!(
            "could not resolve host `{host}` (no addresses)"
        ))),
    }
}
/// GETs `initial_url`, following up to [`MAX_REDIRECTS`] redirects by hand so
/// that every hop is re-validated with [`validate_url_ssrf`].
///
/// The client's automatic redirect handling is disabled. Each `Location`
/// target is instead:
/// 1. resolved against the current URL;
/// 2. required to use an http(s) scheme;
/// 3. run through the full SSRF check (policy, localhost, restricted IPs)
///    before it is requested.
///
/// When validation returns a pinned `(hostname, ip)`, a fresh client is built
/// that resolves `hostname` to exactly that `ip`. The connection then goes to
/// the address that was checked instead of to a second DNS answer.
///
/// `timeout_ms` applies to each hop separately, not to the whole redirect
/// chain. A redirect status without a usable (present, UTF-8) `Location`
/// header is returned to the caller unchanged.
///
/// # Errors
/// - permission denied propagated from [`validate_url_ssrf`], or for a
///   redirect to a non-http(s) scheme;
/// - execution failed on cancellation, client-build or transport failure, an
///   unparseable current URL or redirect location, or when more than
///   `MAX_REDIRECTS` redirects are encountered.
pub(crate) async fn fetch_with_ssrf_guard(
    context: &ToolContext,
    tool_name: &str,
    initial_url: &str,
    user_agent: &str,
    accept: &str,
    timeout_ms: u64,
) -> Result<reqwest::Response, ToolError> {
    let mut current_url = initial_url.to_string();
    let cancel = context.cancel_token.as_ref();
    // One initial request plus at most MAX_REDIRECTS follow-ups.
    for _ in 0..=MAX_REDIRECTS {
        ensure_not_cancelled(cancel)?;
        let pinning = validate_url_ssrf(context, tool_name, &current_url).await?;
        let mut builder = reqwest::Client::builder()
            .timeout(Duration::from_millis(timeout_ms))
            .user_agent(user_agent)
            .redirect(reqwest::redirect::Policy::none());
        if let Some((hostname, ip)) = &pinning {
            // Port 0: reqwest uses the URL's port (or the scheme default).
            builder = builder.resolve(hostname, std::net::SocketAddr::new(*ip, 0));
        }
        // NOTE(review): reqwest honours system proxy settings by default; if a
        // proxy is configured it resolves the host itself and the pin above
        // does not apply. Consider `.no_proxy()` if that matters here.
        let client = builder.build().map_err(|e| {
            ToolError::execution_failed(format!("failed to build HTTP client: {e}"))
        })?;

        // Build the request once so both paths below send the same request.
        let send = client
            .get(&current_url)
            .header("Accept", accept)
            .header("Accept-Language", "en-US,en;q=0.5")
            .send();
        let resp = match cancel {
            Some(token) => {
                tokio::select! {
                    biased;
                    _ = token.cancelled() => {
                        return Err(ToolError::execution_failed("request canceled"));
                    }
                    resp = send => resp,
                }
            }
            None => send.await,
        }
        .map_err(|e| ToolError::execution_failed(format!("request failed: {e}")))?;

        if !resp.status().is_redirection() {
            return Ok(resp);
        }
        let location = resp
            .headers()
            .get(reqwest::header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .map(str::to_string);
        let Some(location) = location else {
            // Nothing to follow; hand the 3xx response back as-is.
            return Ok(resp);
        };

        let base = reqwest::Url::parse(&current_url).map_err(|e| {
            ToolError::execution_failed(format!("invalid current URL `{current_url}`: {e}"))
        })?;
        // `join` handles relative Location values such as "/path" or "../x".
        let next = base.join(&location).map_err(|e| {
            ToolError::execution_failed(format!(
                "invalid redirect location `{location}`: {e}"
            ))
        })?;
        if !is_http_url(next.as_str()) {
            return Err(ToolError::permission_denied(format!(
                "redirect to non-http(s) scheme blocked: {next}"
            )));
        }
        // The next iteration re-validates this URL before requesting it.
        current_url = next.to_string();
    }
    Err(ToolError::execution_failed(format!(
        "too many redirects (>{MAX_REDIRECTS})"
    )))
}
// Unit tests for `validate_url_ssrf`.
//
// NOTE(review): these tests assume `ToolContext::new` yields a network policy
// under which `check_url_policy` lets the `fetch_url` tool reach plain
// `http://` URLs and returns the URL's host. Confirm against `ToolContext` /
// `check_url_policy` if these start failing for policy reasons.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::spec::ToolContext;
    use std::path::PathBuf;

    /// Builds a default tool context rooted at the current directory.
    fn ctx() -> ToolContext {
        ToolContext::new(PathBuf::from("."))
    }

    // The cloud instance-metadata address (link-local 169.254.169.254) is a
    // common SSRF target. As a literal IP it is checked without any DNS lookup.
    #[tokio::test]
    async fn rejects_cloud_metadata_ip() {
        let err = validate_url_ssrf(
            &ctx(),
            "fetch_url",
            "http://169.254.169.254/latest/meta-data/",
        )
        .await
        .expect_err("metadata IP must be rejected");
        assert!(format!("{err}").contains("restricted"));
    }

    // Loopback, RFC 1918 private ranges, and bracketed IPv6 loopback must all
    // be refused via the literal-IP path.
    #[tokio::test]
    async fn rejects_private_and_loopback_ips() {
        for url in [
            "http://127.0.0.1/",
            "http://10.0.0.5/",
            "http://192.168.1.1/admin",
            "http://[::1]/",
        ] {
            let err = validate_url_ssrf(&ctx(), "fetch_url", url)
                .await
                .unwrap_err();
            assert!(
                format!("{err}").contains("restricted"),
                "{url} should be rejected as restricted"
            );
        }
    }

    // The explicit hostname check rejects `localhost` (port included in the
    // URL) before any DNS resolution happens.
    #[tokio::test]
    async fn rejects_localhost_hostname() {
        let err = validate_url_ssrf(&ctx(), "fetch_url", "http://localhost:8080/")
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("localhost"));
    }

    // `.invalid` is a reserved TLD that should never resolve, so this
    // exercises the fail-closed DNS path. It performs a real lookup through the
    // system resolver. "restricted" is also accepted, presumably to tolerate
    // resolvers that answer failed lookups with a local address (verify).
    #[tokio::test]
    async fn dns_failure_fails_closed() {
        let err = validate_url_ssrf(&ctx(), "fetch_url", "http://no-such-host.invalid/")
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("DNS resolution failed")
                || msg.contains("could not resolve")
                || msg.contains("restricted"),
            "unresolvable host must be blocked, got: {msg}"
        );
    }

    // A public literal IP is allowed. The result is `None` because literal IPs
    // need no DNS pinning.
    #[tokio::test]
    async fn allows_public_literal_ip() {
        let res = validate_url_ssrf(&ctx(), "fetch_url", "http://8.8.8.8/")
            .await
            .expect("public IP allowed");
        assert!(res.is_none());
    }
}