// ai_agent/tools/web_fetch.rs

use crate::types::*;
use regex::Regex;
use reqwest::Client;
4
5pub struct WebFetchTool {
6    client: Client,
7}
8
9impl WebFetchTool {
10    pub fn new() -> Self {
11        let client = Client::builder()
12            .timeout(std::time::Duration::from_secs(30))
13            .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
14            .build()
15            .expect("Failed to create HTTP client");
16        Self { client }
17    }
18
19    pub fn name(&self) -> &str {
20        "WebFetch"
21    }
22
23    pub fn description(&self) -> &str {
24        "Fetch content from a URL and return it as text. Supports HTML pages, JSON APIs, and plain text. Strips HTML tags for readability."
25    }
26
27    pub fn input_schema(&self) -> ToolInputSchema {
28        ToolInputSchema {
29            schema_type: "object".to_string(),
30            properties: serde_json::json!({
31                "url": {
32                    "type": "string",
33                    "description": "The URL to fetch content from"
34                },
35                "headers": {
36                    "type": "object",
37                    "description": "Optional HTTP headers",
38                    "additionalProperties": {
39                        "type": "string"
40                    }
41                }
42            }),
43            required: Some(vec!["url".to_string()]),
44        }
45    }
46
47    pub async fn execute(
48        &self,
49        input: serde_json::Value,
50        _context: &ToolContext,
51    ) -> Result<ToolResult, crate::error::AgentError> {
52        let url = input["url"]
53            .as_str()
54            .ok_or_else(|| crate::error::AgentError::Tool("url is required".to_string()))?;
55
56        // Build request with optional headers
57        let mut request = self.client.get(url);
58
59        if let Some(headers) = input["headers"].as_object() {
60            for (key, value) in headers {
61                if let Some(value_str) = value.as_str() {
62                    request = request.header(key, value_str);
63                }
64            }
65        }
66
67        let response = request.send().await.map_err(|e| {
68            crate::error::AgentError::Tool(format!("Error fetching {}: {}", url, e))
69        })?;
70
71        if !response.status().is_success() {
72            return Ok(ToolResult {
73                result_type: "text".to_string(),
74                tool_use_id: "".to_string(),
75                content: format!(
76                    "HTTP {}: {}",
77                    response.status().as_u16(),
78                    response.status().canonical_reason().unwrap_or("Unknown")
79                ),
80                is_error: Some(true),
81            });
82        }
83
84        let content_type = response
85            .headers()
86            .get("content-type")
87            .and_then(|v| v.to_str().ok())
88            .map(|s| s.to_string())
89            .unwrap_or_default();
90
91        let mut text = response.text().await.map_err(|e| {
92            crate::error::AgentError::Tool(format!("Error reading response: {}", e))
93        })?;
94
95        // Strip HTML tags for readability
96        if content_type.contains("text/html") {
97            // Remove script and style blocks
98            let script_regex = Regex::new(r"<script[^>]*>[\s\S]*?</script>").unwrap();
99            text = script_regex.replace_all(&text, "").to_string();
100
101            let style_regex = Regex::new(r"<style[^>]*>[\s\S]*?</style>").unwrap();
102            text = style_regex.replace_all(&text, "").to_string();
103
104            // Remove HTML tags
105            let tag_regex = Regex::new(r"<[^>]+>").unwrap();
106            text = tag_regex.replace_all(&text, " ").to_string();
107
108            // Clean up whitespace
109            let whitespace_regex = Regex::new(r"\s+").unwrap();
110            text = whitespace_regex.replace_all(&text, " ").trim().to_string();
111        }
112
113        // Truncate very large responses
114        if text.len() > 100000 {
115            text.truncate(100000);
116            text.push_str("\n...(truncated)");
117        }
118
119        if text.is_empty() {
120            text = "(empty response)".to_string();
121        }
122
123        Ok(ToolResult {
124            result_type: "text".to_string(),
125            tool_use_id: "".to_string(),
126            content: text,
127            is_error: None,
128        })
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a fresh `WebFetchTool` on `input` with a default context.
    async fn run(input: serde_json::Value) -> Result<ToolResult, crate::error::AgentError> {
        let tool = WebFetchTool::new();
        let context = ToolContext::default();
        tool.execute(input, &context).await
    }

    #[test]
    fn test_web_fetch_tool_name() {
        let tool = WebFetchTool::new();
        assert_eq!(tool.name(), "WebFetch");
    }

    #[test]
    fn test_web_fetch_tool_description_contains_fetch() {
        let tool = WebFetchTool::new();
        assert!(tool.description().to_lowercase().contains("fetch"));
    }

    #[test]
    fn test_web_fetch_tool_has_url_in_schema() {
        let schema = WebFetchTool::new().input_schema();
        assert!(schema.properties.get("url").is_some());
    }

    #[test]
    fn test_web_fetch_tool_has_headers_in_schema() {
        let schema = WebFetchTool::new().input_schema();
        assert!(schema.properties.get("headers").is_some());
    }

    #[tokio::test]
    async fn test_web_fetch_tool_requires_url() {
        // No `url` key at all: the tool must fail before touching the network.
        let result = run(serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin"]
    async fn test_web_fetch_tool_fetches_plain_text() {
        // robots.txt is served by httpbin as a small plain-text document.
        let input = serde_json::json!({
            "url": "https://httpbin.org/robots.txt"
        });

        let tool_result = run(input).await.expect("fetching robots.txt should succeed");
        assert!(!tool_result.content.is_empty());
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin"]
    async fn test_web_fetch_tool_strips_html_tags() {
        // httpbin serves a simple HTML page at /html.
        let input = serde_json::json!({
            "url": "https://httpbin.org/html"
        });

        let tool_result = run(input).await.expect("fetching the HTML page should succeed");
        // HTML tags should be stripped
        assert!(!tool_result.content.contains("<html"));
        assert!(!tool_result.content.contains("<body"));
    }

    #[tokio::test]
    #[ignore = "requires network access"]
    async fn test_web_fetch_tool_returns_error_for_invalid_url() {
        // The host cannot resolve, so sending the request itself must fail.
        let input = serde_json::json!({
            "url": "https://this-domain-does-not-exist-123456.invalid/"
        });

        let result = run(input).await;
        assert!(result.is_err());
    }
}