use crate::tools::{PrimitiveToolName, Tool, ToolContext};
use crate::types::{ToolResult, ToolTier};
use anyhow::{Context, Result, bail};
use serde_json::{Value, json};
use std::time::Duration;
use super::security::UrlValidator;
const MAX_CONTENT_SIZE: usize = 1024 * 1024;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FetchFormat {
#[default]
Text,
Markdown,
}
impl FetchFormat {
fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"text" => Some(Self::Text),
"markdown" | "md" => Some(Self::Markdown),
_ => None,
}
}
}
pub struct LinkFetchTool {
client: reqwest::Client,
validator: UrlValidator,
default_format: FetchFormat,
}
impl Default for LinkFetchTool {
fn default() -> Self {
Self::new()
}
}
impl LinkFetchTool {
#[must_use]
pub fn new() -> Self {
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(DEFAULT_TIMEOUT)
.user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
.build()
.expect("Failed to build HTTP client");
Self {
client,
validator: UrlValidator::new(),
default_format: FetchFormat::Text,
}
}
#[must_use]
pub fn with_validator(mut self, validator: UrlValidator) -> Self {
self.validator = validator;
self
}
#[must_use]
pub fn with_client(mut self, client: reqwest::Client) -> Self {
self.client = client;
self
}
#[must_use]
pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
self.default_format = format;
self
}
async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
let mut url = self.validator.validate(url_str)?;
let max_redirects = self.validator.max_redirects();
let mut response = self
.client
.get(url.as_str())
.send()
.await
.context("Failed to fetch URL")?;
let mut redirects = 0;
while response.status().is_redirection() {
redirects += 1;
if redirects > max_redirects {
bail!("Too many redirects ({redirects} > {max_redirects})");
}
let location = response
.headers()
.get(reqwest::header::LOCATION)
.context("Redirect response missing Location header")?
.to_str()
.context("Invalid Location header")?;
let redirect_url_str = url
.join(location)
.map_or_else(|_| location.to_string(), |u| u.to_string());
url = self.validator.validate(&redirect_url_str)?;
response = self
.client
.get(url.as_str())
.send()
.await
.context("Failed to follow redirect")?;
}
if !response.status().is_success() {
bail!("HTTP error: {}", response.status());
}
if let Some(len) = response.content_length()
&& len > MAX_CONTENT_SIZE as u64
{
bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
}
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("text/html")
.to_string();
let bytes = response
.bytes()
.await
.context("Failed to read response body")?;
if bytes.len() > MAX_CONTENT_SIZE {
bail!(
"Content too large: {} bytes (max {} bytes)",
bytes.len(),
MAX_CONTENT_SIZE
);
}
let html = String::from_utf8_lossy(&bytes);
if content_type.contains("text/html") || content_type.contains("application/xhtml") {
Ok(convert_html(&html, format))
} else if content_type.contains("text/plain") {
Ok(html.into_owned())
} else {
Ok(html.into_owned())
}
}
}
fn convert_html(html: &str, format: FetchFormat) -> String {
let result = match format {
FetchFormat::Text => {
html2text::from_read(html.as_bytes(), 80)
}
FetchFormat::Markdown => {
html2text::from_read(html.as_bytes(), 80)
}
};
result.unwrap_or_else(|_| html.to_string())
}
impl<Ctx> Tool<Ctx> for LinkFetchTool
where
Ctx: Send + Sync + 'static,
{
type Name = PrimitiveToolName;
fn name(&self) -> PrimitiveToolName {
PrimitiveToolName::LinkFetch
}
fn display_name(&self) -> &'static str {
"Fetch URL"
}
fn description(&self) -> &'static str {
"Fetch and read web page content. Returns the page content as text or markdown. \
Includes SSRF protection to prevent access to internal resources."
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch (must be HTTPS)"
},
"format": {
"type": "string",
"enum": ["text", "markdown"],
"description": "Output format (default: text)"
}
},
"required": ["url"]
})
}
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
let url = input
.get("url")
.and_then(Value::as_str)
.context("Missing 'url' parameter")?;
let format = input
.get("format")
.and_then(Value::as_str)
.and_then(FetchFormat::from_str)
.unwrap_or(self.default_format);
match self.fetch_url(url, format).await {
Ok(content) => Ok(ToolResult {
success: true,
output: content,
data: Some(json!({ "url": url, "format": format_name(format) })),
documents: Vec::new(),
duration_ms: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: format!("Failed to fetch URL: {e}"),
data: Some(json!({ "url": url, "error": e.to_string() })),
documents: Vec::new(),
duration_ms: None,
}),
}
}
}
const fn format_name(format: FetchFormat) -> &'static str {
match format {
FetchFormat::Text => "text",
FetchFormat::Markdown => "markdown",
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_link_fetch_tool_metadata() {
let tool = LinkFetchTool::new();
assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::LinkFetch);
assert!(Tool::<()>::description(&tool).contains("Fetch"));
assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
}
#[test]
fn test_link_fetch_tool_input_schema() {
let tool = LinkFetchTool::new();
let schema = Tool::<()>::input_schema(&tool);
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["url"].is_object());
assert!(schema["properties"]["format"].is_object());
assert!(
schema["required"]
.as_array()
.is_some_and(|arr| arr.iter().any(|v| v == "url"))
);
}
#[test]
fn test_format_from_str() {
assert_eq!(FetchFormat::from_str("text"), Some(FetchFormat::Text));
assert_eq!(FetchFormat::from_str("TEXT"), Some(FetchFormat::Text));
assert_eq!(
FetchFormat::from_str("markdown"),
Some(FetchFormat::Markdown)
);
assert_eq!(FetchFormat::from_str("md"), Some(FetchFormat::Markdown));
assert_eq!(FetchFormat::from_str("invalid"), None);
}
#[test]
fn test_convert_html_text() {
let html = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
let result = convert_html(html, FetchFormat::Text);
assert!(result.contains("Title"));
assert!(result.contains("Paragraph"));
}
#[test]
fn test_default_format() {
let tool = LinkFetchTool::new();
assert_eq!(tool.default_format, FetchFormat::Text);
let tool = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
assert_eq!(tool.default_format, FetchFormat::Markdown);
}
#[tokio::test]
async fn test_link_fetch_blocked_url() {
let tool = LinkFetchTool::new();
let ctx = ToolContext::new(());
let input = json!({ "url": "http://localhost:8080" });
let result = Tool::<()>::execute(&tool, &ctx, input).await;
assert!(result.is_ok());
let tool_result = result.expect("Should succeed");
assert!(!tool_result.success);
assert!(
tool_result.output.contains("HTTPS required") || tool_result.output.contains("blocked")
);
}
#[tokio::test]
async fn test_link_fetch_missing_url() {
let tool = LinkFetchTool::new();
let ctx = ToolContext::new(());
let input = json!({});
let result = Tool::<()>::execute(&tool, &ctx, input).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("url"));
}
#[tokio::test]
async fn test_link_fetch_invalid_url() {
let tool = LinkFetchTool::new();
let ctx = ToolContext::new(());
let input = json!({ "url": "not-a-valid-url" });
let result = Tool::<()>::execute(&tool, &ctx, input).await;
assert!(result.is_ok());
let tool_result = result.expect("Should succeed");
assert!(!tool_result.success);
assert!(tool_result.output.contains("Invalid URL"));
}
#[test]
fn test_with_validator() {
let validator = UrlValidator::new().with_allow_http();
let _tool = LinkFetchTool::new().with_validator(validator);
}
#[test]
fn test_format_name() {
assert_eq!(format_name(FetchFormat::Text), "text");
assert_eq!(format_name(FetchFormat::Markdown), "markdown");
}
#[test]
fn test_redirects_disabled_in_client() {
let tool = LinkFetchTool::new();
assert_eq!(tool.validator.max_redirects(), 3);
}
#[tokio::test]
async fn test_redirect_to_private_ip_blocked() {
let validator = UrlValidator::new().with_allow_http();
let result = validator.validate("http://169.254.169.254/latest/meta-data/");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("blocked"));
let result = validator.validate("http://10.0.0.1/internal");
assert!(result.is_err());
}
#[tokio::test]
async fn test_redirect_to_localhost_blocked() {
let validator = UrlValidator::new().with_allow_http();
let result = validator.validate("http://127.0.0.1/admin");
assert!(result.is_err());
}
}