Skip to main content

agent_sdk/web/
fetch.rs

1//! Link fetch tool implementation.
2
3use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Upper bound on the number of body bytes accepted from a response (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Timeout applied to every request made by the default HTTP client.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
16
/// Output format for fetched content.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FetchFormat {
    /// Plain text output (HTML tags removed).
    #[default]
    Text,
    /// Markdown-formatted output.
    Markdown,
}

impl FetchFormat {
    /// Parse a format name, ignoring ASCII case.
    ///
    /// Accepts `"text"` for [`FetchFormat::Text`] and `"markdown"` or `"md"`
    /// for [`FetchFormat::Markdown`]. Any other input (including the empty
    /// string) yields `None`.
    ///
    /// The comparison is ASCII-only and does not allocate, so look-alike
    /// non-ASCII characters are never folded into a recognised name.
    fn from_str(s: &str) -> Option<Self> {
        // Recognised spellings and the variant each one selects.
        const NAMES: [(&str, FetchFormat); 3] = [
            ("text", FetchFormat::Text),
            ("markdown", FetchFormat::Markdown),
            ("md", FetchFormat::Markdown),
        ];

        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, format)| format)
    }
}
36
37/// Link fetch tool for securely retrieving web page content.
38///
39/// This tool fetches web pages and converts them to text or markdown format.
40/// It includes SSRF protection to prevent access to internal resources.
41///
42/// # Example
43///
44/// ```ignore
45/// use agent_sdk::web::LinkFetchTool;
46///
47/// let tool = LinkFetchTool::new();
48///
49/// // Register with agent
50/// tools.register(tool);
51/// ```
52pub struct LinkFetchTool {
53    client: reqwest::Client,
54    validator: UrlValidator,
55    default_format: FetchFormat,
56}
57
58impl Default for LinkFetchTool {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl LinkFetchTool {
65    /// Create a new link fetch tool with default settings.
66    ///
67    /// # Panics
68    ///
69    /// Panics if the HTTP client cannot be built (should never happen with default settings).
70    #[must_use]
71    pub fn new() -> Self {
72        let client = reqwest::Client::builder()
73            .redirect(reqwest::redirect::Policy::none())
74            .timeout(DEFAULT_TIMEOUT)
75            .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
76            .build()
77            .expect("Failed to build HTTP client");
78
79        Self {
80            client,
81            validator: UrlValidator::new(),
82            default_format: FetchFormat::Text,
83        }
84    }
85
86    /// Create with a custom URL validator.
87    #[must_use]
88    pub fn with_validator(mut self, validator: UrlValidator) -> Self {
89        self.validator = validator;
90        self
91    }
92
93    /// Create with a custom HTTP client.
94    #[must_use]
95    pub fn with_client(mut self, client: reqwest::Client) -> Self {
96        self.client = client;
97        self
98    }
99
100    /// Set the default output format.
101    #[must_use]
102    pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
103        self.default_format = format;
104        self
105    }
106
107    /// Fetch a URL and convert to the specified format.
108    ///
109    /// Manually follows redirects, validating each target URL through the
110    /// SSRF validator to prevent redirect-based SSRF attacks.
111    async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
112        // Validate initial URL before fetching
113        let mut url = self.validator.validate(url_str)?;
114        let max_redirects = self.validator.max_redirects();
115
116        let mut response = self
117            .client
118            .get(url.as_str())
119            .send()
120            .await
121            .context("Failed to fetch URL")?;
122
123        // Manually follow redirects with validation
124        let mut redirects = 0;
125        while response.status().is_redirection() {
126            redirects += 1;
127            if redirects > max_redirects {
128                bail!("Too many redirects ({redirects} > {max_redirects})");
129            }
130
131            let location = response
132                .headers()
133                .get(reqwest::header::LOCATION)
134                .context("Redirect response missing Location header")?
135                .to_str()
136                .context("Invalid Location header")?;
137
138            // Resolve relative redirect URLs against the current URL
139            let redirect_url_str = url
140                .join(location)
141                .map_or_else(|_| location.to_string(), |u| u.to_string());
142
143            // Validate the redirect target through the same SSRF checks
144            url = self.validator.validate(&redirect_url_str)?;
145
146            response = self
147                .client
148                .get(url.as_str())
149                .send()
150                .await
151                .context("Failed to follow redirect")?;
152        }
153
154        // Check status
155        if !response.status().is_success() {
156            bail!("HTTP error: {}", response.status());
157        }
158
159        // Check content length if available
160        if let Some(len) = response.content_length()
161            && len > MAX_CONTENT_SIZE as u64
162        {
163            bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
164        }
165
166        // Get content type to determine processing
167        let content_type = response
168            .headers()
169            .get(reqwest::header::CONTENT_TYPE)
170            .and_then(|v| v.to_str().ok())
171            .unwrap_or("text/html")
172            .to_string();
173
174        // Read body with size limit
175        let bytes = response
176            .bytes()
177            .await
178            .context("Failed to read response body")?;
179
180        if bytes.len() > MAX_CONTENT_SIZE {
181            bail!(
182                "Content too large: {} bytes (max {} bytes)",
183                bytes.len(),
184                MAX_CONTENT_SIZE
185            );
186        }
187
188        // Convert to string
189        let html = String::from_utf8_lossy(&bytes);
190
191        // Process based on content type and format
192        if content_type.contains("text/html") || content_type.contains("application/xhtml") {
193            Ok(convert_html(&html, format))
194        } else if content_type.contains("text/plain") {
195            Ok(html.into_owned())
196        } else {
197            // For other content types, just return as-is
198            Ok(html.into_owned())
199        }
200    }
201}
202
203/// Convert HTML to the specified format.
204fn convert_html(html: &str, format: FetchFormat) -> String {
205    let result = match format {
206        FetchFormat::Text => {
207            // Use html2text with default width
208            html2text::from_read(html.as_bytes(), 80)
209        }
210        FetchFormat::Markdown => {
211            // Use html2text with markdown-friendly settings
212            html2text::from_read(html.as_bytes(), 80)
213        }
214    };
215    result.unwrap_or_else(|_| html.to_string())
216}
217
218impl<Ctx> Tool<Ctx> for LinkFetchTool
219where
220    Ctx: Send + Sync + 'static,
221{
222    type Name = PrimitiveToolName;
223
224    fn name(&self) -> PrimitiveToolName {
225        PrimitiveToolName::LinkFetch
226    }
227
228    fn display_name(&self) -> &'static str {
229        "Fetch URL"
230    }
231
232    fn description(&self) -> &'static str {
233        "Fetch and read web page content. Returns the page content as text or markdown. \
234         Includes SSRF protection to prevent access to internal resources."
235    }
236
237    fn input_schema(&self) -> Value {
238        json!({
239            "type": "object",
240            "properties": {
241                "url": {
242                    "type": "string",
243                    "description": "The URL to fetch (must be HTTPS)"
244                },
245                "format": {
246                    "type": "string",
247                    "enum": ["text", "markdown"],
248                    "description": "Output format (default: text)"
249                }
250            },
251            "required": ["url"]
252        })
253    }
254
255    fn tier(&self) -> ToolTier {
256        // Link fetch is read-only, so Observe tier
257        ToolTier::Observe
258    }
259
260    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
261        let url = input
262            .get("url")
263            .and_then(Value::as_str)
264            .context("Missing 'url' parameter")?;
265
266        let format = input
267            .get("format")
268            .and_then(Value::as_str)
269            .and_then(FetchFormat::from_str)
270            .unwrap_or(self.default_format);
271
272        match self.fetch_url(url, format).await {
273            Ok(content) => Ok(ToolResult {
274                success: true,
275                output: content,
276                data: Some(json!({ "url": url, "format": format_name(format) })),
277                documents: Vec::new(),
278                duration_ms: None,
279            }),
280            Err(e) => Ok(ToolResult {
281                success: false,
282                output: format!("Failed to fetch URL: {e}"),
283                data: Some(json!({ "url": url, "error": e.to_string() })),
284                documents: Vec::new(),
285                duration_ms: None,
286            }),
287        }
288    }
289}
290
291/// Get the format name for JSON output.
292const fn format_name(format: FetchFormat) -> &'static str {
293    match format {
294        FetchFormat::Text => "text",
295        FetchFormat::Markdown => "markdown",
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::LinkFetch);
        assert!(Tool::<()>::description(&tool).contains("Fetch"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let tool = LinkFetchTool::new();

        let schema = Tool::<()>::input_schema(&tool);
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["url"].is_object());
        assert!(schema["properties"]["format"].is_object());
        assert!(
            schema["required"]
                .as_array()
                .is_some_and(|arr| arr.iter().any(|v| v == "url"))
        );
    }

    #[test]
    fn test_format_from_str() {
        assert_eq!(FetchFormat::from_str("text"), Some(FetchFormat::Text));
        assert_eq!(FetchFormat::from_str("TEXT"), Some(FetchFormat::Text));
        assert_eq!(
            FetchFormat::from_str("markdown"),
            Some(FetchFormat::Markdown)
        );
        assert_eq!(FetchFormat::from_str("md"), Some(FetchFormat::Markdown));
        assert_eq!(FetchFormat::from_str("invalid"), None);
    }

    #[test]
    fn test_convert_html_text() {
        let html = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
        let result = convert_html(html, FetchFormat::Text);
        assert!(result.contains("Title"));
        assert!(result.contains("Paragraph"));
    }

    #[test]
    fn test_convert_html_markdown() {
        let html = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
        let result = convert_html(html, FetchFormat::Markdown);
        assert!(result.contains("Title"));
        assert!(result.contains("Paragraph"));
    }

    #[test]
    fn test_default_format() {
        let tool = LinkFetchTool::new();
        assert_eq!(tool.default_format, FetchFormat::Text);

        let tool = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(tool.default_format, FetchFormat::Markdown);
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({ "url": "http://localhost:8080" });

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(
            tool_result.output.contains("HTTPS required") || tool_result.output.contains("blocked")
        );
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({});

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({ "url": "not-a-valid-url" });

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        let validator = UrlValidator::new().with_allow_http();
        let _tool = LinkFetchTool::new().with_validator(validator);
        // Just verify it compiles - validator is private
    }

    #[test]
    fn test_format_name() {
        assert_eq!(format_name(FetchFormat::Text), "text");
        assert_eq!(format_name(FetchFormat::Markdown), "markdown");
    }

    #[test]
    fn test_redirects_disabled_in_client() {
        // The client's redirect policy cannot be inspected from here, so this
        // only pins the hop limit that the manual redirect loop reads from the
        // default validator.
        let tool = LinkFetchTool::new();
        assert_eq!(tool.validator.max_redirects(), 3);
    }

    // The two tests below exercise the validator directly with the kind of URL
    // a malicious redirect would point at. Validation is synchronous, so no
    // async runtime is needed.

    #[test]
    fn test_redirect_to_private_ip_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        // Link-local metadata address should be blocked
        let result = validator.validate("http://169.254.169.254/latest/meta-data/");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));

        // Direct access to 10.x should be blocked
        let result = validator.validate("http://10.0.0.1/internal");
        assert!(result.is_err());
    }

    #[test]
    fn test_redirect_to_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        // Redirect target pointing to localhost should be blocked
        let result = validator.validate("http://127.0.0.1/admin");
        assert!(result.is_err());
    }
}