use crate::tools::{PrimitiveToolName, Tool, ToolContext};
use crate::types::{ToolResult, ToolTier};
use anyhow::{Context, Result, bail};
use serde_json::{Value, json};
use std::time::Duration;
use super::security::UrlValidator;
/// Largest response body, in bytes, that a fetch will accept (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;
/// Overall request timeout configured on HTTP clients built by this module.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Output format selector for fetched content.
///
/// Only a single variant exists, and it is also the `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FetchFormat {
    /// Plain-text output.
    #[default]
    Text,
}
/// Tool that fetches a URL and returns its body as text.
///
/// Every URL (including each redirect target) is passed through the
/// configured [`UrlValidator`] before a request is sent.
pub struct LinkFetchTool {
/// Optional caller-supplied HTTP client. When `None`, a fresh client is
/// built for each request.
client: Option<reqwest::Client>,
/// Validator applied to the initial URL and to every redirect target; it
/// also supplies the redirect limit.
validator: UrlValidator,
}
impl Default for LinkFetchTool {
fn default() -> Self {
Self::new()
}
}
impl LinkFetchTool {
    /// Creates a tool with no preconfigured HTTP client and a default
    /// [`UrlValidator`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            client: None,
            validator: UrlValidator::new(),
        }
    }

    /// Replaces the URL validator used for the initial URL and for every
    /// redirect target.
    #[must_use]
    pub fn with_validator(mut self, validator: UrlValidator) -> Self {
        self.validator = validator;
        self
    }

    /// Supplies a ready-made HTTP client.
    ///
    /// The supplied client is used as-is: the redirect policy, timeout,
    /// user agent and address pinning that this tool configures on its own
    /// clients are NOT applied to it.
    // NOTE(review): `fetch_url` validates redirect targets by following them
    // manually, which presumably requires the supplied client to have
    // automatic redirects disabled -- confirm with callers.
    #[must_use]
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.client = Some(client);
        self
    }

    /// Returns the HTTP client to use for a single request.
    ///
    /// If a client was supplied via [`Self::with_client`], a clone of it is
    /// returned unchanged. Otherwise a new client is built with automatic
    /// redirects disabled and the default timeout; when both `host` and a
    /// non-empty `addrs` are given, the host is pinned to those addresses so
    /// the connection goes to the addresses that were validated.
    ///
    /// # Errors
    /// Fails if the underlying client cannot be built.
    fn build_client(
        &self,
        host: Option<&str>,
        addrs: &[std::net::SocketAddr],
    ) -> Result<reqwest::Client> {
        if let Some(client) = &self.client {
            return Ok(client.clone());
        }
        let mut builder = reqwest::Client::builder()
            .redirect(reqwest::redirect::Policy::none())
            .timeout(DEFAULT_TIMEOUT)
            .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)");
        if let Some(host) = host
            && !addrs.is_empty()
        {
            builder = builder.resolve_to_addrs(host, addrs);
        }
        builder.build().context("Failed to build HTTP client")
    }

    /// Reports whether a `Content-Type` header value denotes HTML or XHTML.
    ///
    /// Media types are case-insensitive, so the value is lower-cased before
    /// matching; parameters such as `; charset=utf-8` are tolerated because
    /// the check is a substring match.
    fn is_html_content_type(content_type: &str) -> bool {
        let normalized = content_type.to_ascii_lowercase();
        normalized.contains("text/html") || normalized.contains("application/xhtml")
    }

    /// Fetches `url_str` and returns its body as text.
    ///
    /// The URL is validated first. Redirects are followed manually, up to the
    /// validator's limit, and every redirect target is validated again before
    /// it is requested. HTML/XHTML bodies (or bodies with no usable
    /// `Content-Type`) are converted to text; anything else is returned as
    /// lossily decoded UTF-8.
    ///
    /// # Errors
    /// Fails on validation errors, transport errors, too many redirects, a
    /// redirect without a usable `Location` header, a non-success status, or
    /// a body larger than `MAX_CONTENT_SIZE`.
    async fn fetch_url(&self, url_str: &str) -> Result<String> {
        let mut validated = self.validator.validate(url_str).await?;
        let max_redirects = self.validator.max_redirects();
        let client = self.build_client(validated.url.host_str(), &validated.addresses)?;
        let mut response = client
            .get(validated.url.as_str())
            .send()
            .await
            .context("Failed to fetch URL")?;

        let mut redirects = 0;
        while response.status().is_redirection() {
            redirects += 1;
            if redirects > max_redirects {
                bail!("Too many redirects ({redirects} > {max_redirects})");
            }
            let location = response
                .headers()
                .get(reqwest::header::LOCATION)
                .context("Redirect response missing Location header")?
                .to_str()
                .context("Invalid Location header")?;
            // Resolve relative locations against the current URL; if joining
            // fails, hand the raw value to the validator and let it decide.
            let redirect_url_str = validated
                .url
                .join(location)
                .map_or_else(|_| location.to_string(), |u| u.to_string());
            // Re-validate every hop so a redirect cannot bypass the checks
            // applied to the initial URL.
            validated = self.validator.validate(&redirect_url_str).await?;
            let client = self.build_client(validated.url.host_str(), &validated.addresses)?;
            response = client
                .get(validated.url.as_str())
                .send()
                .await
                .context("Failed to follow redirect")?;
        }

        if !response.status().is_success() {
            bail!("HTTP error: {}", response.status());
        }

        // Cheap early rejection when the server declares the size up front;
        // the streaming read below enforces the cap regardless.
        if let Some(len) = response.content_length()
            && len > MAX_CONTENT_SIZE as u64
        {
            bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
        }

        // A missing or non-text header value is treated as HTML.
        let is_html = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .map_or(true, Self::is_html_content_type);

        let bytes = read_capped_body(&mut response, MAX_CONTENT_SIZE).await?;
        let body = String::from_utf8_lossy(&bytes);
        if is_html {
            Ok(convert_html(&body))
        } else {
            Ok(body.into_owned())
        }
    }
}
/// Renders HTML as text wrapped at 80 columns.
///
/// If the conversion fails, the input is returned unchanged.
fn convert_html(html: &str) -> String {
    match html2text::from_read(html.as_bytes(), 80) {
        Ok(rendered) => rendered,
        Err(_) => html.to_string(),
    }
}
/// Streams the response body into memory, refusing to hold more than `max`
/// bytes.
///
/// The limit is checked before each chunk is appended, so the buffer never
/// grows past `max`.
///
/// # Errors
/// Fails if reading a chunk fails or if the accumulated body would exceed
/// `max` bytes.
async fn read_capped_body(response: &mut reqwest::Response, max: usize) -> Result<Vec<u8>> {
    let mut body: Vec<u8> = Vec::new();
    loop {
        let next = response
            .chunk()
            .await
            .context("Failed to read response body")?;
        let Some(piece) = next else {
            // End of stream: everything read so far fits within the cap.
            return Ok(body);
        };
        let total = body.len() + piece.len();
        if total > max {
            bail!("Content too large: exceeds {max} bytes");
        }
        body.extend_from_slice(&piece);
    }
}
impl<Ctx> Tool<Ctx> for LinkFetchTool
where
Ctx: Send + Sync + 'static,
{
type Name = PrimitiveToolName;
fn name(&self) -> PrimitiveToolName {
PrimitiveToolName::LinkFetch
}
fn display_name(&self) -> &'static str {
"Fetch URL"
}
fn description(&self) -> &'static str {
"Fetch and read web page content. Returns the page content as text or markdown. \
Includes SSRF protection to prevent access to internal resources."
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch (must be HTTPS)"
}
},
"required": ["url"]
})
}
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
let url = input
.get("url")
.and_then(Value::as_str)
.context("Missing 'url' parameter")?;
match self.fetch_url(url).await {
Ok(content) => Ok(ToolResult {
success: true,
output: content,
data: Some(json!({ "url": url })),
documents: Vec::new(),
duration_ms: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: format!("Failed to fetch URL: {e}"),
data: Some(json!({ "url": url, "error": e.to_string() })),
documents: Vec::new(),
duration_ms: None,
}),
}
}
}
// Unit tests for the link-fetch tool: metadata, input handling, URL
// validation, and the streaming body-size cap.
#[cfg(test)]
mod tests {
use super::*;
// The tool reports the expected name, a description mentioning "Fetch",
// and the Observe tier.
#[test]
fn test_link_fetch_tool_metadata() {
let tool = LinkFetchTool::new();
assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::LinkFetch);
assert!(Tool::<()>::description(&tool).contains("Fetch"));
assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
}
// The schema is an object with a required `url` property and no `format`
// property.
#[test]
fn test_link_fetch_tool_input_schema() {
let tool = LinkFetchTool::new();
let schema = Tool::<()>::input_schema(&tool);
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["url"].is_object());
assert!(schema["properties"]["format"].is_null());
assert!(
schema["required"]
.as_array()
.is_some_and(|arr| arr.iter().any(|v| v == "url"))
);
}
// Text inside heading and paragraph elements survives HTML conversion.
#[test]
fn test_convert_html_text() {
let html = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
let result = convert_html(html);
assert!(result.contains("Title"));
assert!(result.contains("Paragraph"));
}
// A plain-HTTP localhost URL is rejected by the default validator; the
// rejection comes back as an unsuccessful result, not as an `Err`.
#[tokio::test]
async fn test_link_fetch_blocked_url() {
let tool = LinkFetchTool::new();
let ctx = ToolContext::new(());
let input = json!({ "url": "http://localhost:8080" });
let result = Tool::<()>::execute(&tool, &ctx, input).await;
assert!(result.is_ok());
let tool_result = result.expect("Should succeed");
assert!(!tool_result.success);
assert!(
tool_result.output.contains("HTTPS required") || tool_result.output.contains("blocked")
);
}
// A missing `url` parameter is the one case where `execute` returns `Err`.
#[tokio::test]
async fn test_link_fetch_missing_url() {
let tool = LinkFetchTool::new();
let ctx = ToolContext::new(());
let input = json!({});
let result = Tool::<()>::execute(&tool, &ctx, input).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("url"));
}
// A string that is not a URL yields an unsuccessful result mentioning
// "Invalid URL".
#[tokio::test]
async fn test_link_fetch_invalid_url() {
let tool = LinkFetchTool::new();
let ctx = ToolContext::new(());
let input = json!({ "url": "not-a-valid-url" });
let result = Tool::<()>::execute(&tool, &ctx, input).await;
assert!(result.is_ok());
let tool_result = result.expect("Should succeed");
assert!(!tool_result.success);
assert!(tool_result.output.contains("Invalid URL"));
}
// Smoke test: the builder accepts a custom validator. Nothing is asserted.
#[test]
fn test_with_validator() {
let validator = UrlValidator::new().with_allow_http();
let _tool = LinkFetchTool::new().with_validator(validator);
}
// NOTE(review): despite its name, this test only checks that the default
// validator's redirect limit is 3; it does not inspect the HTTP client's
// redirect policy.
#[test]
fn test_redirects_disabled_in_client() {
let tool = LinkFetchTool::new();
assert_eq!(tool.validator.max_redirects(), 3);
}
// Exercises the validator directly (no redirect is performed): the
// link-local address 169.254.169.254 and the private address 10.0.0.1 are
// rejected even when plain HTTP is allowed.
#[tokio::test]
async fn test_redirect_to_private_ip_blocked() {
let validator = UrlValidator::new().with_allow_http();
let result = validator
.validate("http://169.254.169.254/latest/meta-data/")
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("blocked"));
let result = validator.validate("http://10.0.0.1/internal").await;
assert!(result.is_err());
}
// Exercises the validator directly: the loopback address is rejected even
// when plain HTTP is allowed.
#[tokio::test]
async fn test_redirect_to_localhost_blocked() {
let validator = UrlValidator::new().with_allow_http();
let result = validator.validate("http://127.0.0.1/admin").await;
assert!(result.is_err());
}
// A local server streams 40 x 64 KiB (2.5 MiB) without a Content-Length
// header, so only the streaming cap can reject it; the 1 MiB limit must
// trip with a "Content too large" error.
#[tokio::test]
async fn test_read_capped_body_rejects_oversized_stream() -> Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
// Port 0 lets the OS pick a free port.
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let server = tokio::spawn(async move {
if let Ok((mut sock, _)) = listener.accept().await {
let mut buf = [0u8; 1024];
// Read (and discard) the request before replying.
let _ = sock.read(&mut buf).await;
let header =
"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n";
let _ = sock.write_all(header.as_bytes()).await;
let chunk = vec![b'a'; 64 * 1024];
for _ in 0..40 {
// The client stops reading once the cap trips, so a write
// failure here is expected and simply ends the loop.
if sock.write_all(&chunk).await.is_err() {
break;
}
}
let _ = sock.shutdown().await;
}
});
let client = reqwest::Client::builder().build()?;
let mut response = client.get(format!("http://{addr}/big")).send().await?;
let result = read_capped_body(&mut response, 1024 * 1024).await;
server.abort();
assert!(result.is_err(), "oversized streamed body must be rejected");
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Content too large"),
"expected size-cap error, got: {msg}"
);
Ok(())
}
// A small body with a declared Content-Length is read back unchanged.
#[tokio::test]
async fn test_read_capped_body_accepts_small_stream() -> Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let server = tokio::spawn(async move {
if let Ok((mut sock, _)) = listener.accept().await {
let mut buf = [0u8; 1024];
// Read (and discard) the request before replying.
let _ = sock.read(&mut buf).await;
let body = "hello world";
let resp = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
body.len()
);
let _ = sock.write_all(resp.as_bytes()).await;
let _ = sock.shutdown().await;
}
});
let client = reqwest::Client::builder().build()?;
let mut response = client.get(format!("http://{addr}/small")).send().await?;
let bytes = read_capped_body(&mut response, 1024 * 1024).await?;
server.abort();
assert_eq!(String::from_utf8_lossy(&bytes), "hello world");
Ok(())
}
}