ricecoder_tools/
webfetch.rs1use crate::error::ToolError;
6use crate::result::ToolResult;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::str::FromStr;
10use std::time::Instant;
11use tracing::{debug, warn};
12use url::Url;
13
/// Default cap, in bytes, on the content returned by a fetch (50 KiB).
///
/// `WebfetchTool::fetch` falls back to this value when `WebfetchInput::max_size` is `None`.
const MAX_CONTENT_SIZE: usize = 50 * 1024;

/// Timeout, in seconds, used both for the HTTP client and for the overall fetch operation.
const REQUEST_TIMEOUT_SECS: u64 = 10;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct WebfetchInput {
23 pub url: String,
25 pub max_size: Option<usize>,
27}
28
29impl WebfetchInput {
30 pub fn new(url: impl Into<String>) -> Self {
32 Self {
33 url: url.into(),
34 max_size: None,
35 }
36 }
37
38 pub fn with_max_size(mut self, max_size: usize) -> Self {
40 self.max_size = Some(max_size);
41 self
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct WebfetchOutput {
48 pub content: String,
50 pub truncated: bool,
52 pub original_size: usize,
54 pub returned_size: usize,
56}
57
58impl WebfetchOutput {
59 pub fn new(content: String, original_size: usize) -> Self {
61 let returned_size = content.len();
62 let truncated = returned_size < original_size;
63
64 Self {
65 content,
66 truncated,
67 original_size,
68 returned_size,
69 }
70 }
71}
72
/// Built-in web fetch tool backed by a `reqwest::Client`.
pub struct WebfetchTool {
    /// HTTP client used to issue the requests made by `fetch`.
    client: reqwest::Client,
}
77
78impl WebfetchTool {
79 pub fn new() -> Result<Self, ToolError> {
81 let client = reqwest::Client::builder()
82 .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
83 .build()
84 .map_err(|e| {
85 ToolError::new("CLIENT_ERROR", "Failed to create HTTP client")
86 .with_details(e.to_string())
87 })?;
88
89 Ok(Self { client })
90 }
91
92 pub fn validate_url(url: &str) -> Result<Url, ToolError> {
94 let parsed_url = Url::parse(url).map_err(|e| {
96 ToolError::new("INVALID_URL", "Invalid URL format")
97 .with_details(e.to_string())
98 .with_suggestion("Ensure URL is properly formatted (e.g., https://example.com)")
99 })?;
100
101 match parsed_url.scheme() {
103 "http" | "https" => {}
104 _ => {
105 return Err(ToolError::new("INVALID_SCHEME", "Only http and https schemes are supported")
106 .with_suggestion("Use http:// or https:// URLs"))
107 }
108 }
109
110 if let Some(host) = parsed_url.host_str() {
112 if host == "localhost" || host == "127.0.0.1" || host == "::1" {
114 return Err(ToolError::new("SSRF_PREVENTION", "Localhost URLs are not allowed")
115 .with_suggestion("Use a public URL instead"));
116 }
117
118 if let Ok(ip) = IpAddr::from_str(host) {
120 let is_private = match ip {
121 IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_loopback(),
122 IpAddr::V6(ipv6) => ipv6.is_loopback(),
123 };
124
125 if is_private {
126 return Err(ToolError::new("SSRF_PREVENTION", "Private IP addresses are not allowed")
127 .with_details(format!("IP: {}", ip))
128 .with_suggestion("Use a public URL instead"));
129 }
130 }
131 }
132
133 Ok(parsed_url)
134 }
135
136 pub async fn fetch(&self, input: WebfetchInput) -> ToolResult<WebfetchOutput> {
138 let start = Instant::now();
139
140 match Self::validate_url(&input.url) {
142 Ok(_) => {
143 debug!("URL validation passed: {}", input.url);
144 }
145 Err(e) => {
146 let duration_ms = start.elapsed().as_millis() as u64;
147 return ToolResult::err(e, duration_ms, "builtin");
148 }
149 }
150
151 let max_size = input.max_size.unwrap_or(MAX_CONTENT_SIZE);
153
154 let timeout_duration = std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS);
156 let fetch_future = async {
157 match self.client.get(&input.url).send().await {
159 Ok(response) => {
160 if !response.status().is_success() {
162 let error = ToolError::new(
163 "HTTP_ERROR",
164 format!("HTTP error: {}", response.status()),
165 )
166 .with_details(format!("Status code: {}", response.status().as_u16()))
167 .with_suggestion("Check the URL and try again");
168 return Err(error);
169 }
170
171 match response.bytes().await {
173 Ok(bytes) => {
174 let original_size = bytes.len();
175
176 let content = if original_size > max_size {
178 warn!(
179 "Content truncated from {} to {} bytes",
180 original_size, max_size
181 );
182 String::from_utf8_lossy(&bytes[..max_size]).to_string()
183 } else {
184 String::from_utf8_lossy(&bytes).to_string()
185 };
186
187 let output = WebfetchOutput::new(content, original_size);
188 Ok(output)
189 }
190 Err(e) => {
191 let error = ToolError::from(e);
192 Err(error)
193 }
194 }
195 }
196 Err(e) => {
197 let error = ToolError::from(e);
198 Err(error)
199 }
200 }
201 };
202
203 match tokio::time::timeout(timeout_duration, fetch_future).await {
204 Ok(Ok(output)) => {
205 let duration_ms = start.elapsed().as_millis() as u64;
206 ToolResult::ok(output, duration_ms, "builtin")
207 }
208 Ok(Err(error)) => {
209 let duration_ms = start.elapsed().as_millis() as u64;
210 ToolResult::err(error, duration_ms, "builtin")
211 }
212 Err(_) => {
213 let duration_ms = start.elapsed().as_millis() as u64;
214 let error = ToolError::new("TIMEOUT", "Webfetch operation exceeded 10 second timeout")
215 .with_details(format!("URL: {}", input.url))
216 .with_suggestion("Try again with a different URL or check your network connection");
217 ToolResult::err(error, duration_ms, "builtin")
218 }
219 }
220 }
221}
222
223impl Default for WebfetchTool {
224 fn default() -> Self {
225 Self::new().expect("Failed to create default WebfetchTool")
226 }
227}
228
229pub struct WebfetchToolWithMcp {
231 builtin: WebfetchTool,
232 mcp_client: Option<ricecoder_mcp::MCPClient>,
233}
234
235impl WebfetchToolWithMcp {
236 pub fn new(mcp_client: Option<ricecoder_mcp::MCPClient>) -> Result<Self, ToolError> {
238 let builtin = WebfetchTool::new()?;
239 Ok(Self { builtin, mcp_client })
240 }
241
242 async fn is_mcp_available(&self) -> bool {
244 if let Some(client) = &self.mcp_client {
245 if let Ok(servers) = client.discover_servers().await {
247 for server_id in servers {
248 if let Ok(tools) = client.discover_tools(&server_id).await {
249 if tools.iter().any(|t| t.id.contains("webfetch")) {
250 debug!("MCP webfetch server available: {}", server_id);
251 return true;
252 }
253 }
254 }
255 }
256 }
257 false
258 }
259
260 pub async fn fetch(&self, input: WebfetchInput) -> ToolResult<WebfetchOutput> {
262 if self.is_mcp_available().await {
264 debug!("Attempting to use MCP webfetch provider");
265 }
268
269 debug!("Using built-in webfetch provider");
271 self.builtin.fetch(input).await
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the URL and leaves the size cap unset.
    #[test]
    fn test_webfetch_input_creation() {
        let input = WebfetchInput::new("https://example.com");
        assert_eq!(input.url, "https://example.com");
        assert!(input.max_size.is_none());
    }

    /// `with_max_size` records the requested cap.
    #[test]
    fn test_webfetch_input_with_max_size() {
        let input = WebfetchInput::new("https://example.com").with_max_size(1024);
        assert_eq!(input.max_size, Some(1024));
    }

    /// Content as long as the original is not flagged as truncated.
    #[test]
    fn test_webfetch_output_creation() {
        let output = WebfetchOutput::new("test content".to_string(), 12);
        assert_eq!(output.content, "test content");
        assert_eq!(output.original_size, 12);
        assert_eq!(output.returned_size, 12);
        assert!(!output.truncated);
    }

    /// Content shorter than the original is flagged as truncated.
    #[test]
    fn test_webfetch_output_truncation() {
        let output = WebfetchOutput::new("test".to_string(), 100);
        assert_eq!(output.content, "test");
        assert_eq!(output.original_size, 100);
        assert_eq!(output.returned_size, 4);
        assert!(output.truncated);
    }

    #[test]
    fn test_validate_url_valid_https() {
        let result = WebfetchTool::validate_url("https://example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_url_valid_http() {
        let result = WebfetchTool::validate_url("http://example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_url_invalid_format() {
        let err = WebfetchTool::validate_url("not a url")
            .expect_err("malformed input must be rejected");
        assert_eq!(err.code, "INVALID_URL");
    }

    #[test]
    fn test_validate_url_invalid_scheme() {
        let err = WebfetchTool::validate_url("ftp://example.com")
            .expect_err("non-http(s) schemes must be rejected");
        assert_eq!(err.code, "INVALID_SCHEME");
    }

    #[test]
    fn test_validate_url_localhost_rejection() {
        let err = WebfetchTool::validate_url("http://localhost:8080")
            .expect_err("localhost must be rejected");
        assert_eq!(err.code, "SSRF_PREVENTION");
    }

    #[test]
    fn test_validate_url_127_0_0_1_rejection() {
        let err = WebfetchTool::validate_url("http://127.0.0.1")
            .expect_err("loopback address must be rejected");
        assert_eq!(err.code, "SSRF_PREVENTION");
    }

    #[test]
    fn test_validate_url_private_ip_rejection() {
        let err = WebfetchTool::validate_url("http://192.168.1.1")
            .expect_err("private address must be rejected");
        assert_eq!(err.code, "SSRF_PREVENTION");
    }

    #[test]
    fn test_webfetch_tool_creation() {
        let tool = WebfetchTool::new();
        assert!(tool.is_ok());
    }

    /// Exercises the timeout path against a deliberately slow remote endpoint.
    ///
    /// Ignored by default: it needs outbound network access and runs for the full timeout.
    /// Run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to httpbin.org and takes ~10 seconds"]
    async fn test_webfetch_timeout_enforcement() {
        let tool = WebfetchTool::new().unwrap();
        let input = WebfetchInput::new("http://httpbin.org/delay/15");
        let result = tool.fetch(input).await;

        assert!(!result.success);
        assert!(result.error.is_some());
        if let Some(error) = result.error {
            assert!(
                error.code == "TIMEOUT" || error.code == "HTTP_ERROR",
                "Expected TIMEOUT or HTTP_ERROR, got: {}",
                error.code
            );
        }
    }
}
392