Skip to main content

ai_agent/tools/
web_fetch.rs

1// Source: ~/claudecode/openclaudecode/src/tools/WebFetchTool/WebFetchTool.ts
2//! WebFetch tool - fetch URL content.
3//!
4//! Fetches URLs and converts to text/markdown.
5
6use crate::error::AgentError;
7use crate::types::*;
8use crate::utils::http::get_user_agent;
9use regex::Regex;
10use reqwest::Client;
11use std::collections::HashSet;
12use std::path::PathBuf;
13use std::sync::OnceLock;
14
/// Returns the set of hosts that are preapproved for fetching
/// (mirrors the TS `PREAPPROVED_HOSTS` list).
///
/// A fresh set is built on every call; entries are exact host names.
fn preapproved_hosts() -> HashSet<&'static str> {
    const HOSTS: [&str; 13] = [
        "httpbin.org",
        "jsonplaceholder.typicode.com",
        "api.github.com",
        "raw.githubusercontent.com",
        "gist.githubusercontent.com",
        "registry.npmjs.org",
        "pypi.org",
        "crates.io",
        "docs.rs",
        "developer.mozilla.org",
        "stackoverflow.com",
        "wikipedia.org",
        "www.wikipedia.org",
    ];

    HOSTS.iter().copied().collect()
}
33
/// Location of the directory where binary tool results are persisted:
/// `ai-tool-results` inside the system temporary directory.
///
/// Only computes the path; nothing is created on disk.
fn tool_results_dir_path() -> PathBuf {
    let mut results_dir = std::env::temp_dir();
    results_dir.push("ai-tool-results");
    results_dir
}
38
39async fn tool_results_dir() -> PathBuf {
40    let dir = tool_results_dir_path();
41    tokio::fs::create_dir_all(&dir).await.ok();
42    dir
43}
44
45pub struct WebFetchTool {
46    client: Client,
47}
48
49impl WebFetchTool {
50    pub fn new() -> Self {
51        let client = Client::builder()
52            .timeout(std::time::Duration::from_secs(30))
53            .user_agent(get_user_agent())
54            .redirect(reqwest::redirect::Policy::limited(5)) // Handle redirects (max 5, matching TS)
55            .build()
56            .expect("Failed to create HTTP client");
57        Self { client }
58    }
59
60    pub fn name(&self) -> &str {
61        "WebFetch"
62    }
63
64    pub fn description(&self) -> &str {
65        "Fetch content from a URL and return it as text. Supports HTML pages, JSON APIs, and plain text. \
66        Strips HTML tags for readability. Preapproved hosts can be fetched without additional permission."
67    }
68
69    pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
70        "WebFetch".to_string()
71    }
72
73    pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
74        input.and_then(|inp| inp["url"].as_str().map(String::from))
75    }
76
77    pub fn render_tool_result_message(
78        &self,
79        content: &serde_json::Value,
80    ) -> Option<String> {
81        let text = content["content"].as_str()?;
82        let lines = text.lines().count();
83        Some(format!("{} lines", lines))
84    }
85
86    pub fn input_schema(&self) -> ToolInputSchema {
87        ToolInputSchema {
88            schema_type: "object".to_string(),
89            properties: serde_json::json!({
90                "url": {
91                    "type": "string",
92                    "description": "The URL to fetch content from"
93                },
94                "headers": {
95                    "type": "object",
96                    "description": "Optional HTTP headers",
97                    "additionalProperties": {
98                        "type": "string"
99                    }
100                },
101                "prompt": {
102                    "type": "string",
103                    "description": "Optional prompt for LLM-based content extraction. If provided, the content will be extracted using this prompt."
104                }
105            }),
106            required: Some(vec!["url".to_string()]),
107        }
108    }
109
110    pub async fn execute(
111        &self,
112        input: serde_json::Value,
113        _context: &ToolContext,
114    ) -> Result<ToolResult, AgentError> {
115        let url = input["url"]
116            .as_str()
117            .ok_or_else(|| AgentError::Tool("url is required".to_string()))?;
118
119        // Validate host against preapproved list
120        let host = self.extract_host(url)?;
121        let is_preapproved = preapproved_hosts().contains(host.as_str());
122
123        if !is_preapproved {
124            // In a full implementation, this would check permission rules
125            // For now, warn but allow (TS requires permission check for non-preapproved hosts)
126        }
127
128        // Build request with optional headers
129        let mut request = self.client.get(url);
130
131        if let Some(headers) = input["headers"].as_object() {
132            for (key, value) in headers {
133                if let Some(value_str) = value.as_str() {
134                    request = request.header(key, value_str);
135                }
136            }
137        }
138
139        let response = request.send().await.map_err(|e| {
140            // Handle redirect errors gracefully
141            if e.is_redirect() {
142                AgentError::Tool(format!("Redirect error fetching {}: {}", url, e))
143            } else if e.is_timeout() {
144                AgentError::Tool(format!("Timeout fetching {}: {}", url, e))
145            } else if e.is_connect() {
146                AgentError::Tool(format!("Connection error fetching {}: {}", url, e))
147            } else {
148                AgentError::Tool(format!("Error fetching {}: {}", url, e))
149            }
150        })?;
151
152        let status = response.status();
153        let final_url = response.url().to_string();
154
155        // Handle redirect chain info
156        let redirect_note = if final_url != url {
157            format!("\n(Redirected from {} to {})", url, final_url)
158        } else {
159            String::new()
160        };
161
162        if !status.is_success() {
163            return Ok(ToolResult {
164                result_type: "text".to_string(),
165                tool_use_id: "".to_string(),
166                content: format!(
167                    "HTTP {}: {}{}",
168                    status.as_u16(),
169                    status.canonical_reason().unwrap_or("Unknown"),
170                    redirect_note
171                ),
172                is_error: Some(true),
173                was_persisted: None,
174            });
175        }
176
177        let content_type = response
178            .headers()
179            .get("content-type")
180            .and_then(|v| v.to_str().ok())
181            .map(|s| s.to_string())
182            .unwrap_or_default();
183
184        let bytes = response
185            .bytes()
186            .await
187            .map_err(|e| AgentError::Tool(format!("Error reading response: {}", e)))?;
188
189        // Check if binary content
190        if self.is_binary_content(&content_type, &bytes) {
191            // Save binary content to disk (matching TS: binary persistence)
192            let filename = format!("webfetch_{}", self.hash_url(url));
193            let path = tool_results_dir().await.join(&filename);
194            tokio::fs::write(&path, &bytes)
195                .await
196                .map_err(|e| AgentError::Tool(format!("Failed to save binary content: {}", e)))?;
197
198            return Ok(ToolResult {
199                result_type: "text".to_string(),
200                tool_use_id: "".to_string(),
201                content: format!(
202                    "Binary content fetched and saved to disk: {}\n\
203                    Content-Type: {}\n\
204                    Size: {} bytes{}",
205                    path.display(),
206                    content_type,
207                    bytes.len(),
208                    redirect_note
209                ),
210                is_error: None,
211                was_persisted: None,
212            });
213        }
214
215        let mut text = String::from_utf8_lossy(&bytes).to_string();
216
217        // Strip HTML tags for readability (matching TS)
218        if content_type.contains("text/html") {
219            // Remove script and style blocks
220            let script_regex = Regex::new(r"(?s)<script[^>]*>[\s\S]*?</script>").unwrap();
221            text = script_regex.replace_all(&text, "").to_string();
222
223            let style_regex = Regex::new(r"(?s)<style[^>]*>[\s\S]*?</style>").unwrap();
224            text = style_regex.replace_all(&text, "").to_string();
225
226            // Remove HTML tags
227            let tag_regex = Regex::new(r"<[^>]+>").unwrap();
228            text = tag_regex.replace_all(&text, " ").to_string();
229
230            // Clean up whitespace
231            let whitespace_regex = Regex::new(r"\s+").unwrap();
232            text = whitespace_regex.replace_all(&text, " ").trim().to_string();
233        }
234
235        // Decode HTML entities (basic)
236        text = text
237            .replace("&amp;", "&")
238            .replace("&lt;", "<")
239            .replace("&gt;", ">")
240            .replace("&quot;", "\"")
241            .replace("&#39;", "'")
242            .replace("&nbsp;", " ");
243
244        // Truncate very large responses (100K chars matching TS)
245        if text.len() > 100000 {
246            text.truncate(100000);
247            text.push_str("\n...(truncated)");
248        }
249
250        if text.is_empty() {
251            text = "(empty response)".to_string();
252        }
253
254        Ok(ToolResult {
255            result_type: "text".to_string(),
256            tool_use_id: "".to_string(),
257            content: format!("{}{}", text, redirect_note),
258            is_error: None,
259            was_persisted: None,
260        })
261    }
262
263    /// Extract host from URL
264    fn extract_host(&self, url: &str) -> Result<String, AgentError> {
265        url::Url::parse(url)
266            .map(|u| u.host_str().unwrap_or("").to_string())
267            .map_err(|e| AgentError::Tool(format!("Invalid URL {}: {}", url, e)))
268    }
269
270    /// Check if content is binary
271    fn is_binary_content(&self, content_type: &str, bytes: &[u8]) -> bool {
272        // Check content type
273        let binary_types = [
274            "image/",
275            "audio/",
276            "video/",
277            "application/octet-stream",
278            "application/zip",
279            "application/gzip",
280            "application/pdf",
281            "application/x-",
282            "font/",
283        ];
284        if binary_types.iter().any(|t| content_type.starts_with(t)) {
285            return true;
286        }
287
288        // Check for binary content via null bytes in first 512 bytes
289        let sample = &bytes[..bytes.len().min(512)];
290        sample.iter().any(|&b| b == 0)
291    }
292
293    /// Hash URL for filename
294    fn hash_url(&self, url: &str) -> String {
295        use std::collections::hash_map::DefaultHasher;
296        use std::hash::{Hash, Hasher};
297        let mut hasher = DefaultHasher::new();
298        url.hash(&mut hasher);
299        format!("{:x}", hasher.finish())
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// The tool reports its registered name.
    #[test]
    fn test_web_fetch_tool_name() {
        assert_eq!(WebFetchTool::new().name(), "WebFetch");
    }

    /// The input schema declares every supported property.
    #[test]
    fn test_web_fetch_tool_has_url_in_schema() {
        let properties = WebFetchTool::new().input_schema().properties;
        for key in ["url", "headers", "prompt"].iter() {
            assert!(properties.get(*key).is_some());
        }
    }

    /// Binary detection by content type, with textual types left alone.
    #[test]
    fn test_web_fetch_tool_is_binary_content() {
        let tool = WebFetchTool::new();

        let png_magic: &[u8] = &[0x89, 0x50, 0x4E, 0x47];
        assert!(tool.is_binary_content("image/png", png_magic));
        assert!(tool.is_binary_content("application/octet-stream", b"hello"));

        let textual_cases = [
            ("text/html", "<html>hello</html>"),
            ("application/json", "{\"key\": \"value\"}"),
        ];
        for (content_type, body) in textual_cases.iter() {
            assert!(!tool.is_binary_content(content_type, body.as_bytes()));
        }
    }

    /// Host extraction for plain http and https URLs.
    #[test]
    fn test_web_fetch_tool_extract_host() {
        let tool = WebFetchTool::new();
        let cases = [
            ("https://example.com/path", "example.com"),
            ("http://api.github.com/repos", "api.github.com"),
        ];
        for (url, expected_host) in cases.iter() {
            assert_eq!(tool.extract_host(url).unwrap(), *expected_host);
        }
    }
}
343}