1use crate::error::AgentError;
7use crate::types::*;
8use crate::utils::http::get_user_agent;
9use regex::Regex;
10use reqwest::Client;
11use std::collections::HashSet;
12use std::path::PathBuf;
13use std::sync::OnceLock;
14
/// Returns the set of host names that are treated as preapproved.
///
/// Membership is an exact string match on the host; a fresh set is built on
/// every call.
fn preapproved_hosts() -> HashSet<&'static str> {
    const HOSTS: [&str; 13] = [
        "httpbin.org",
        "jsonplaceholder.typicode.com",
        "api.github.com",
        "raw.githubusercontent.com",
        "gist.githubusercontent.com",
        "registry.npmjs.org",
        "pypi.org",
        "crates.io",
        "docs.rs",
        "developer.mozilla.org",
        "stackoverflow.com",
        "wikipedia.org",
        "www.wikipedia.org",
    ];

    HOSTS.iter().copied().collect()
}
33
/// Builds the path of the `ai-tool-results` directory inside the system
/// temporary directory. Nothing is created on disk by this function.
fn tool_results_dir_path() -> PathBuf {
    let mut results_dir = std::env::temp_dir();
    results_dir.push("ai-tool-results");
    results_dir
}
38
39async fn tool_results_dir() -> PathBuf {
40 let dir = tool_results_dir_path();
41 tokio::fs::create_dir_all(&dir).await.ok();
42 dir
43}
44
/// Tool that fetches the content of a URL over HTTP.
pub struct WebFetchTool {
    /// HTTP client used to issue the requests.
    client: Client,
}
48
49impl WebFetchTool {
50 pub fn new() -> Self {
51 let client = Client::builder()
52 .timeout(std::time::Duration::from_secs(30))
53 .user_agent(get_user_agent())
54 .redirect(reqwest::redirect::Policy::limited(5)) .build()
56 .expect("Failed to create HTTP client");
57 Self { client }
58 }
59
60 pub fn name(&self) -> &str {
61 "WebFetch"
62 }
63
64 pub fn description(&self) -> &str {
65 "Fetch content from a URL and return it as text. Supports HTML pages, JSON APIs, and plain text. \
66 Strips HTML tags for readability. Preapproved hosts can be fetched without additional permission."
67 }
68
69 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
70 "WebFetch".to_string()
71 }
72
73 pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
74 input.and_then(|inp| inp["url"].as_str().map(String::from))
75 }
76
77 pub fn render_tool_result_message(
78 &self,
79 content: &serde_json::Value,
80 ) -> Option<String> {
81 let text = content["content"].as_str()?;
82 let lines = text.lines().count();
83 Some(format!("{} lines", lines))
84 }
85
86 pub fn input_schema(&self) -> ToolInputSchema {
87 ToolInputSchema {
88 schema_type: "object".to_string(),
89 properties: serde_json::json!({
90 "url": {
91 "type": "string",
92 "description": "The URL to fetch content from"
93 },
94 "headers": {
95 "type": "object",
96 "description": "Optional HTTP headers",
97 "additionalProperties": {
98 "type": "string"
99 }
100 },
101 "prompt": {
102 "type": "string",
103 "description": "Optional prompt for LLM-based content extraction. If provided, the content will be extracted using this prompt."
104 }
105 }),
106 required: Some(vec!["url".to_string()]),
107 }
108 }
109
110 pub async fn execute(
111 &self,
112 input: serde_json::Value,
113 _context: &ToolContext,
114 ) -> Result<ToolResult, AgentError> {
115 let url = input["url"]
116 .as_str()
117 .ok_or_else(|| AgentError::Tool("url is required".to_string()))?;
118
119 let host = self.extract_host(url)?;
121 let is_preapproved = preapproved_hosts().contains(host.as_str());
122
123 if !is_preapproved {
124 }
127
128 let mut request = self.client.get(url);
130
131 if let Some(headers) = input["headers"].as_object() {
132 for (key, value) in headers {
133 if let Some(value_str) = value.as_str() {
134 request = request.header(key, value_str);
135 }
136 }
137 }
138
139 let response = request.send().await.map_err(|e| {
140 if e.is_redirect() {
142 AgentError::Tool(format!("Redirect error fetching {}: {}", url, e))
143 } else if e.is_timeout() {
144 AgentError::Tool(format!("Timeout fetching {}: {}", url, e))
145 } else if e.is_connect() {
146 AgentError::Tool(format!("Connection error fetching {}: {}", url, e))
147 } else {
148 AgentError::Tool(format!("Error fetching {}: {}", url, e))
149 }
150 })?;
151
152 let status = response.status();
153 let final_url = response.url().to_string();
154
155 let redirect_note = if final_url != url {
157 format!("\n(Redirected from {} to {})", url, final_url)
158 } else {
159 String::new()
160 };
161
162 if !status.is_success() {
163 return Ok(ToolResult {
164 result_type: "text".to_string(),
165 tool_use_id: "".to_string(),
166 content: format!(
167 "HTTP {}: {}{}",
168 status.as_u16(),
169 status.canonical_reason().unwrap_or("Unknown"),
170 redirect_note
171 ),
172 is_error: Some(true),
173 was_persisted: None,
174 });
175 }
176
177 let content_type = response
178 .headers()
179 .get("content-type")
180 .and_then(|v| v.to_str().ok())
181 .map(|s| s.to_string())
182 .unwrap_or_default();
183
184 let bytes = response
185 .bytes()
186 .await
187 .map_err(|e| AgentError::Tool(format!("Error reading response: {}", e)))?;
188
189 if self.is_binary_content(&content_type, &bytes) {
191 let filename = format!("webfetch_{}", self.hash_url(url));
193 let path = tool_results_dir().await.join(&filename);
194 tokio::fs::write(&path, &bytes)
195 .await
196 .map_err(|e| AgentError::Tool(format!("Failed to save binary content: {}", e)))?;
197
198 return Ok(ToolResult {
199 result_type: "text".to_string(),
200 tool_use_id: "".to_string(),
201 content: format!(
202 "Binary content fetched and saved to disk: {}\n\
203 Content-Type: {}\n\
204 Size: {} bytes{}",
205 path.display(),
206 content_type,
207 bytes.len(),
208 redirect_note
209 ),
210 is_error: None,
211 was_persisted: None,
212 });
213 }
214
215 let mut text = String::from_utf8_lossy(&bytes).to_string();
216
217 if content_type.contains("text/html") {
219 let script_regex = Regex::new(r"(?s)<script[^>]*>[\s\S]*?</script>").unwrap();
221 text = script_regex.replace_all(&text, "").to_string();
222
223 let style_regex = Regex::new(r"(?s)<style[^>]*>[\s\S]*?</style>").unwrap();
224 text = style_regex.replace_all(&text, "").to_string();
225
226 let tag_regex = Regex::new(r"<[^>]+>").unwrap();
228 text = tag_regex.replace_all(&text, " ").to_string();
229
230 let whitespace_regex = Regex::new(r"\s+").unwrap();
232 text = whitespace_regex.replace_all(&text, " ").trim().to_string();
233 }
234
235 text = text
237 .replace("&", "&")
238 .replace("<", "<")
239 .replace(">", ">")
240 .replace(""", "\"")
241 .replace("'", "'")
242 .replace(" ", " ");
243
244 if text.len() > 100000 {
246 text.truncate(100000);
247 text.push_str("\n...(truncated)");
248 }
249
250 if text.is_empty() {
251 text = "(empty response)".to_string();
252 }
253
254 Ok(ToolResult {
255 result_type: "text".to_string(),
256 tool_use_id: "".to_string(),
257 content: format!("{}{}", text, redirect_note),
258 is_error: None,
259 was_persisted: None,
260 })
261 }
262
263 fn extract_host(&self, url: &str) -> Result<String, AgentError> {
265 url::Url::parse(url)
266 .map(|u| u.host_str().unwrap_or("").to_string())
267 .map_err(|e| AgentError::Tool(format!("Invalid URL {}: {}", url, e)))
268 }
269
270 fn is_binary_content(&self, content_type: &str, bytes: &[u8]) -> bool {
272 let binary_types = [
274 "image/",
275 "audio/",
276 "video/",
277 "application/octet-stream",
278 "application/zip",
279 "application/gzip",
280 "application/pdf",
281 "application/x-",
282 "font/",
283 ];
284 if binary_types.iter().any(|t| content_type.starts_with(t)) {
285 return true;
286 }
287
288 let sample = &bytes[..bytes.len().min(512)];
290 sample.iter().any(|&b| b == 0)
291 }
292
293 fn hash_url(&self, url: &str) -> String {
295 use std::collections::hash_map::DefaultHasher;
296 use std::hash::{Hash, Hasher};
297 let mut hasher = DefaultHasher::new();
298 url.hash(&mut hasher);
299 format!("{:x}", hasher.finish())
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// The tool reports its registered name.
    #[test]
    fn test_web_fetch_tool_name() {
        assert_eq!(WebFetchTool::new().name(), "WebFetch");
    }

    /// The schema exposes all three documented properties.
    #[test]
    fn test_web_fetch_tool_has_url_in_schema() {
        let properties = WebFetchTool::new().input_schema().properties;
        for key in ["url", "headers", "prompt"] {
            assert!(properties.get(key).is_some());
        }
    }

    /// Binary detection follows the content type for the sampled bodies.
    #[test]
    fn test_web_fetch_tool_is_binary_content() {
        let tool = WebFetchTool::new();
        let is_binary =
            |content_type: &str, body: &[u8]| tool.is_binary_content(content_type, body);

        assert!(is_binary("image/png", &[0x89, 0x50, 0x4E, 0x47]));
        assert!(is_binary("application/octet-stream", b"hello"));
        assert!(!is_binary("text/html", b"<html>hello</html>"));
        assert!(!is_binary("application/json", b"{\"key\": \"value\"}"));
    }

    /// The host is extracted regardless of scheme and path.
    #[test]
    fn test_web_fetch_tool_extract_host() {
        let tool = WebFetchTool::new();
        let cases = [
            ("https://example.com/path", "example.com"),
            ("http://api.github.com/repos", "api.github.com"),
        ];
        for (url, expected_host) in cases {
            assert_eq!(tool.extract_host(url).unwrap(), expected_host);
        }
    }
}
343}