ricecoder_tools/
webfetch.rs

1//! Webfetch tool for fetching web content
2//!
3//! Provides functionality to fetch and process web content from URLs with MCP integration.
4
5use crate::error::ToolError;
6use crate::result::ToolResult;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::str::FromStr;
10use std::time::Instant;
11use tracing::{debug, warn};
12use url::Url;
13
/// Default cap on the number of body bytes handed back to the caller:
/// 50 KiB (51,200 bytes). Used when `WebfetchInput::max_size` is `None`.
const MAX_CONTENT_SIZE: usize = 1024 * 50;

/// Time budget for one fetch, in seconds. Applied both as the HTTP client
/// timeout and as the outer `tokio::time::timeout` guard.
const REQUEST_TIMEOUT_SECS: u64 = 10;
19
20/// Input for webfetch operation
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct WebfetchInput {
23    /// URL to fetch
24    pub url: String,
25    /// Optional maximum content size in bytes
26    pub max_size: Option<usize>,
27}
28
29impl WebfetchInput {
30    /// Create a new webfetch input
31    pub fn new(url: impl Into<String>) -> Self {
32        Self {
33            url: url.into(),
34            max_size: None,
35        }
36    }
37
38    /// Set maximum content size
39    pub fn with_max_size(mut self, max_size: usize) -> Self {
40        self.max_size = Some(max_size);
41        self
42    }
43}
44
45/// Output for webfetch operation
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct WebfetchOutput {
48    /// Fetched content
49    pub content: String,
50    /// Whether content was truncated
51    pub truncated: bool,
52    /// Original content size in bytes
53    pub original_size: usize,
54    /// Actual returned content size in bytes
55    pub returned_size: usize,
56}
57
58impl WebfetchOutput {
59    /// Create a new webfetch output
60    pub fn new(content: String, original_size: usize) -> Self {
61        let returned_size = content.len();
62        let truncated = returned_size < original_size;
63
64        Self {
65            content,
66            truncated,
67            original_size,
68            returned_size,
69        }
70    }
71}
72
73/// Webfetch tool for fetching web content
74pub struct WebfetchTool {
75    client: reqwest::Client,
76}
77
78impl WebfetchTool {
79    /// Create a new webfetch tool
80    pub fn new() -> Result<Self, ToolError> {
81        let client = reqwest::Client::builder()
82            .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
83            .build()
84            .map_err(|e| {
85                ToolError::new("CLIENT_ERROR", "Failed to create HTTP client")
86                    .with_details(e.to_string())
87            })?;
88
89        Ok(Self { client })
90    }
91
92    /// Validate URL format and security
93    pub fn validate_url(url: &str) -> Result<Url, ToolError> {
94        // Parse URL
95        let parsed_url = Url::parse(url).map_err(|e| {
96            ToolError::new("INVALID_URL", "Invalid URL format")
97                .with_details(e.to_string())
98                .with_suggestion("Ensure URL is properly formatted (e.g., https://example.com)")
99        })?;
100
101        // Check scheme is http or https
102        match parsed_url.scheme() {
103            "http" | "https" => {}
104            _ => {
105                return Err(ToolError::new("INVALID_SCHEME", "Only http and https schemes are supported")
106                    .with_suggestion("Use http:// or https:// URLs"))
107            }
108        }
109
110        // Check for SSRF attacks - reject private IPs and localhost
111        if let Some(host) = parsed_url.host_str() {
112            // Reject localhost
113            if host == "localhost" || host == "127.0.0.1" || host == "::1" {
114                return Err(ToolError::new("SSRF_PREVENTION", "Localhost URLs are not allowed")
115                    .with_suggestion("Use a public URL instead"));
116            }
117
118            // Try to parse as IP address and check if private
119            if let Ok(ip) = IpAddr::from_str(host) {
120                let is_private = match ip {
121                    IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_loopback(),
122                    IpAddr::V6(ipv6) => ipv6.is_loopback(),
123                };
124
125                if is_private {
126                    return Err(ToolError::new("SSRF_PREVENTION", "Private IP addresses are not allowed")
127                        .with_details(format!("IP: {}", ip))
128                        .with_suggestion("Use a public URL instead"));
129                }
130            }
131        }
132
133        Ok(parsed_url)
134    }
135
136    /// Fetch content from a URL with timeout enforcement
137    pub async fn fetch(&self, input: WebfetchInput) -> ToolResult<WebfetchOutput> {
138        let start = Instant::now();
139
140        // Validate URL
141        match Self::validate_url(&input.url) {
142            Ok(_) => {
143                debug!("URL validation passed: {}", input.url);
144            }
145            Err(e) => {
146                let duration_ms = start.elapsed().as_millis() as u64;
147                return ToolResult::err(e, duration_ms, "builtin");
148            }
149        }
150
151        // Determine max size
152        let max_size = input.max_size.unwrap_or(MAX_CONTENT_SIZE);
153
154        // Enforce timeout using tokio::time::timeout
155        let timeout_duration = std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS);
156        let fetch_future = async {
157            // Fetch content
158            match self.client.get(&input.url).send().await {
159                Ok(response) => {
160                    // Check status
161                    if !response.status().is_success() {
162                        let error = ToolError::new(
163                            "HTTP_ERROR",
164                            format!("HTTP error: {}", response.status()),
165                        )
166                        .with_details(format!("Status code: {}", response.status().as_u16()))
167                        .with_suggestion("Check the URL and try again");
168                        return Err(error);
169                    }
170
171                    // Fetch body
172                    match response.bytes().await {
173                        Ok(bytes) => {
174                            let original_size = bytes.len();
175
176                            // Truncate if necessary
177                            let content = if original_size > max_size {
178                                warn!(
179                                    "Content truncated from {} to {} bytes",
180                                    original_size, max_size
181                                );
182                                String::from_utf8_lossy(&bytes[..max_size]).to_string()
183                            } else {
184                                String::from_utf8_lossy(&bytes).to_string()
185                            };
186
187                            let output = WebfetchOutput::new(content, original_size);
188                            Ok(output)
189                        }
190                        Err(e) => {
191                            let error = ToolError::from(e);
192                            Err(error)
193                        }
194                    }
195                }
196                Err(e) => {
197                    let error = ToolError::from(e);
198                    Err(error)
199                }
200            }
201        };
202
203        match tokio::time::timeout(timeout_duration, fetch_future).await {
204            Ok(Ok(output)) => {
205                let duration_ms = start.elapsed().as_millis() as u64;
206                ToolResult::ok(output, duration_ms, "builtin")
207            }
208            Ok(Err(error)) => {
209                let duration_ms = start.elapsed().as_millis() as u64;
210                ToolResult::err(error, duration_ms, "builtin")
211            }
212            Err(_) => {
213                let duration_ms = start.elapsed().as_millis() as u64;
214                let error = ToolError::new("TIMEOUT", "Webfetch operation exceeded 10 second timeout")
215                    .with_details(format!("URL: {}", input.url))
216                    .with_suggestion("Try again with a different URL or check your network connection");
217                ToolResult::err(error, duration_ms, "builtin")
218            }
219        }
220    }
221}
222
223impl Default for WebfetchTool {
224    fn default() -> Self {
225        Self::new().expect("Failed to create default WebfetchTool")
226    }
227}
228
229/// Webfetch tool with MCP integration
230pub struct WebfetchToolWithMcp {
231    builtin: WebfetchTool,
232    mcp_client: Option<ricecoder_mcp::MCPClient>,
233}
234
235impl WebfetchToolWithMcp {
236    /// Create a new webfetch tool with MCP integration
237    pub fn new(mcp_client: Option<ricecoder_mcp::MCPClient>) -> Result<Self, ToolError> {
238        let builtin = WebfetchTool::new()?;
239        Ok(Self { builtin, mcp_client })
240    }
241
242    /// Check if MCP server is available for webfetch
243    async fn is_mcp_available(&self) -> bool {
244        if let Some(client) = &self.mcp_client {
245            // Try to discover webfetch tools from MCP servers
246            if let Ok(servers) = client.discover_servers().await {
247                for server_id in servers {
248                    if let Ok(tools) = client.discover_tools(&server_id).await {
249                        if tools.iter().any(|t| t.id.contains("webfetch")) {
250                            debug!("MCP webfetch server available: {}", server_id);
251                            return true;
252                        }
253                    }
254                }
255            }
256        }
257        false
258    }
259
260    /// Fetch content with MCP fallback
261    pub async fn fetch(&self, input: WebfetchInput) -> ToolResult<WebfetchOutput> {
262        // Try MCP first if available
263        if self.is_mcp_available().await {
264            debug!("Attempting to use MCP webfetch provider");
265            // In a real implementation, we would call the MCP server here
266            // For now, we fall back to built-in
267        }
268
269        // Fall back to built-in implementation
270        debug!("Using built-in webfetch provider");
271        self.builtin.fetch(input).await
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `validate_url` rejects `url` with the error code `expected_code`.
    ///
    /// Fails loudly when validation unexpectedly succeeds, instead of
    /// silently skipping the error-code check.
    fn assert_rejected(url: &str, expected_code: &str) {
        match WebfetchTool::validate_url(url) {
            Err(e) => assert_eq!(e.code, expected_code, "unexpected error code for {}", url),
            Ok(_) => panic!("expected {} for {}, but validation passed", expected_code, url),
        }
    }

    #[test]
    fn test_webfetch_input_creation() {
        let input = WebfetchInput::new("https://example.com");
        assert_eq!(input.url, "https://example.com");
        assert!(input.max_size.is_none());
    }

    #[test]
    fn test_webfetch_input_with_max_size() {
        let input = WebfetchInput::new("https://example.com").with_max_size(1024);
        assert_eq!(input.max_size, Some(1024));
    }

    #[test]
    fn test_webfetch_output_creation() {
        let output = WebfetchOutput::new("test content".to_string(), 12);
        assert_eq!(output.content, "test content");
        assert_eq!(output.original_size, 12);
        assert_eq!(output.returned_size, 12);
        assert!(!output.truncated);
    }

    #[test]
    fn test_webfetch_output_truncation() {
        let output = WebfetchOutput::new("test".to_string(), 100);
        assert_eq!(output.content, "test");
        assert_eq!(output.original_size, 100);
        assert_eq!(output.returned_size, 4);
        assert!(output.truncated);
    }

    #[test]
    fn test_validate_url_valid_https() {
        let result = WebfetchTool::validate_url("https://example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_url_valid_http() {
        let result = WebfetchTool::validate_url("http://example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_url_invalid_format() {
        assert_rejected("not a url", "INVALID_URL");
    }

    #[test]
    fn test_validate_url_invalid_scheme() {
        assert_rejected("ftp://example.com", "INVALID_SCHEME");
    }

    #[test]
    fn test_validate_url_localhost_rejection() {
        assert_rejected("http://localhost:8080", "SSRF_PREVENTION");
    }

    #[test]
    fn test_validate_url_127_0_0_1_rejection() {
        assert_rejected("http://127.0.0.1", "SSRF_PREVENTION");
    }

    #[test]
    fn test_validate_url_private_ip_rejection() {
        assert_rejected("http://192.168.1.1", "SSRF_PREVENTION");
    }

    #[test]
    fn test_webfetch_tool_creation() {
        let tool = WebfetchTool::new();
        assert!(tool.is_ok());
    }

    // Talks to a live third-party endpoint and waits for the full timeout,
    // so it is excluded from the default test run. Run it explicitly with
    // `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to httpbin.org and waits for the request timeout"]
    async fn test_webfetch_timeout_enforcement() {
        // The endpoint delays its response for 15 seconds, longer than the
        // request timeout.
        let tool = WebfetchTool::new().unwrap();
        let input = WebfetchInput::new("http://httpbin.org/delay/15");

        let result = tool.fetch(input).await;

        // Should fail (either timeout or HTTP error due to network conditions)
        assert!(!result.success);
        assert!(result.error.is_some());
        if let Some(error) = result.error {
            // Accept either TIMEOUT or HTTP_ERROR as both indicate the request failed
            assert!(
                error.code == "TIMEOUT" || error.code == "HTTP_ERROR",
                "Expected TIMEOUT or HTTP_ERROR, got: {}",
                error.code
            );
        }
    }
}
392