agent_sdk/web/
fetch.rs

1//! Link fetch tool implementation.
2
3use crate::tools::{Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use serde_json::{Value, json};
8use std::time::Duration;
9
10use super::security::UrlValidator;
11
/// Upper bound on the number of body bytes accepted from a response (1 MiB, i.e. 2^20 bytes).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Request timeout applied to the HTTP client built by `LinkFetchTool::new` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
17
/// Output format for fetched content.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FetchFormat {
    /// Plain text output (HTML tags removed).
    #[default]
    Text,
    /// Markdown-formatted output.
    Markdown,
}

impl FetchFormat {
    /// Parse a format name supplied by a caller.
    ///
    /// Accepts `"text"`, `"markdown"` and the alias `"md"`, compared without regard to
    /// ASCII case. Returns `None` for anything else (including the empty string) so the
    /// caller can decide on a fallback.
    ///
    /// The comparison is done in place with `eq_ignore_ascii_case`, so no temporary
    /// lowercase `String` is allocated per call.
    fn from_str(s: &str) -> Option<Self> {
        if s.eq_ignore_ascii_case("text") {
            Some(Self::Text)
        } else if s.eq_ignore_ascii_case("markdown") || s.eq_ignore_ascii_case("md") {
            Some(Self::Markdown)
        } else {
            None
        }
    }
}
37
/// Link fetch tool for securely retrieving web page content.
///
/// This tool fetches web pages and converts them to text or markdown format.
/// Every URL is passed through a [`UrlValidator`] before a request is sent, which is
/// how the tool guards against SSRF (access to internal resources).
///
/// # Example
///
/// ```ignore
/// use agent_sdk::web::LinkFetchTool;
///
/// let tool = LinkFetchTool::new();
///
/// // Register with agent
/// tools.register(tool);
/// ```
pub struct LinkFetchTool {
    /// HTTP client used to issue the GET request; replaceable via `with_client`.
    client: reqwest::Client,
    /// Validator applied to the requested URL before fetching; replaceable via `with_validator`.
    validator: UrlValidator,
    /// Format used when the tool input has no `format` field or one that is not recognised.
    default_format: FetchFormat,
}
58
59impl Default for LinkFetchTool {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl LinkFetchTool {
66    /// Create a new link fetch tool with default settings.
67    ///
68    /// # Panics
69    ///
70    /// Panics if the HTTP client cannot be built (should never happen with default settings).
71    #[must_use]
72    pub fn new() -> Self {
73        let client = reqwest::Client::builder()
74            .timeout(DEFAULT_TIMEOUT)
75            .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
76            .build()
77            .expect("Failed to build HTTP client");
78
79        Self {
80            client,
81            validator: UrlValidator::new(),
82            default_format: FetchFormat::Text,
83        }
84    }
85
86    /// Create with a custom URL validator.
87    #[must_use]
88    pub fn with_validator(mut self, validator: UrlValidator) -> Self {
89        self.validator = validator;
90        self
91    }
92
93    /// Create with a custom HTTP client.
94    #[must_use]
95    pub fn with_client(mut self, client: reqwest::Client) -> Self {
96        self.client = client;
97        self
98    }
99
100    /// Set the default output format.
101    #[must_use]
102    pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
103        self.default_format = format;
104        self
105    }
106
107    /// Fetch a URL and convert to the specified format.
108    async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
109        // Validate URL before fetching
110        let url = self.validator.validate(url_str)?;
111
112        // Build request with redirect policy
113        let response = self
114            .client
115            .get(url.as_str())
116            .send()
117            .await
118            .context("Failed to fetch URL")?;
119
120        // Check status
121        if !response.status().is_success() {
122            bail!("HTTP error: {}", response.status());
123        }
124
125        // Check content length if available
126        if let Some(len) = response.content_length()
127            && len > MAX_CONTENT_SIZE as u64
128        {
129            bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
130        }
131
132        // Get content type to determine processing
133        let content_type = response
134            .headers()
135            .get(reqwest::header::CONTENT_TYPE)
136            .and_then(|v| v.to_str().ok())
137            .unwrap_or("text/html")
138            .to_string();
139
140        // Read body with size limit
141        let bytes = response
142            .bytes()
143            .await
144            .context("Failed to read response body")?;
145
146        if bytes.len() > MAX_CONTENT_SIZE {
147            bail!(
148                "Content too large: {} bytes (max {} bytes)",
149                bytes.len(),
150                MAX_CONTENT_SIZE
151            );
152        }
153
154        // Convert to string
155        let html = String::from_utf8_lossy(&bytes);
156
157        // Process based on content type and format
158        if content_type.contains("text/html") || content_type.contains("application/xhtml") {
159            Ok(convert_html(&html, format))
160        } else if content_type.contains("text/plain") {
161            Ok(html.into_owned())
162        } else {
163            // For other content types, just return as-is
164            Ok(html.into_owned())
165        }
166    }
167}
168
169/// Convert HTML to the specified format.
170fn convert_html(html: &str, format: FetchFormat) -> String {
171    let result = match format {
172        FetchFormat::Text => {
173            // Use html2text with default width
174            html2text::from_read(html.as_bytes(), 80)
175        }
176        FetchFormat::Markdown => {
177            // Use html2text with markdown-friendly settings
178            html2text::from_read(html.as_bytes(), 80)
179        }
180    };
181    result.unwrap_or_else(|_| html.to_string())
182}
183
184#[async_trait]
185impl<Ctx> Tool<Ctx> for LinkFetchTool
186where
187    Ctx: Send + Sync + 'static,
188{
189    fn name(&self) -> &'static str {
190        "link_fetch"
191    }
192
193    fn description(&self) -> &'static str {
194        "Fetch and read web page content. Returns the page content as text or markdown. \
195         Includes SSRF protection to prevent access to internal resources."
196    }
197
198    fn input_schema(&self) -> Value {
199        json!({
200            "type": "object",
201            "properties": {
202                "url": {
203                    "type": "string",
204                    "description": "The URL to fetch (must be HTTPS)"
205                },
206                "format": {
207                    "type": "string",
208                    "enum": ["text", "markdown"],
209                    "description": "Output format (default: text)"
210                }
211            },
212            "required": ["url"]
213        })
214    }
215
216    fn tier(&self) -> ToolTier {
217        // Link fetch is read-only, so Observe tier
218        ToolTier::Observe
219    }
220
221    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
222        let url = input
223            .get("url")
224            .and_then(Value::as_str)
225            .context("Missing 'url' parameter")?;
226
227        let format = input
228            .get("format")
229            .and_then(Value::as_str)
230            .and_then(FetchFormat::from_str)
231            .unwrap_or(self.default_format);
232
233        match self.fetch_url(url, format).await {
234            Ok(content) => Ok(ToolResult {
235                success: true,
236                output: content,
237                data: Some(json!({ "url": url, "format": format_name(format) })),
238                duration_ms: None,
239            }),
240            Err(e) => Ok(ToolResult {
241                success: false,
242                output: format!("Failed to fetch URL: {e}"),
243                data: Some(json!({ "url": url, "error": e.to_string() })),
244                duration_ms: None,
245            }),
246        }
247    }
248}
249
250/// Get the format name for JSON output.
251const fn format_name(format: FetchFormat) -> &'static str {
252    match format {
253        FetchFormat::Text => "text",
254        FetchFormat::Markdown => "markdown",
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a default-configured tool once against a unit context.
    async fn run(input: Value) -> Result<ToolResult> {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        <LinkFetchTool as Tool<()>>::execute(&tool, &ctx, input).await
    }

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();

        let name = <LinkFetchTool as Tool<()>>::name(&tool);
        let description = <LinkFetchTool as Tool<()>>::description(&tool);
        let tier = <LinkFetchTool as Tool<()>>::tier(&tool);

        assert_eq!(name, "link_fetch");
        assert!(description.contains("Fetch"));
        assert_eq!(tier, ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let tool = LinkFetchTool::new();
        let schema = <LinkFetchTool as Tool<()>>::input_schema(&tool);

        assert_eq!(schema["type"], "object");

        let properties = &schema["properties"];
        assert!(properties["url"].is_object());
        assert!(properties["format"].is_object());

        let url_is_required = schema["required"]
            .as_array()
            .is_some_and(|required| required.iter().any(|entry| entry == "url"));
        assert!(url_is_required);
    }

    #[test]
    fn test_format_from_str() {
        let cases = [
            ("text", Some(FetchFormat::Text)),
            ("TEXT", Some(FetchFormat::Text)),
            ("markdown", Some(FetchFormat::Markdown)),
            ("md", Some(FetchFormat::Markdown)),
            ("invalid", None),
        ];

        for (raw, expected) in cases {
            assert_eq!(FetchFormat::from_str(raw), expected);
        }
    }

    #[test]
    fn test_convert_html_text() {
        let rendered = convert_html(
            "<html><body><h1>Title</h1><p>Paragraph</p></body></html>",
            FetchFormat::Text,
        );

        for needle in ["Title", "Paragraph"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_default_format() {
        let plain = LinkFetchTool::new();
        assert_eq!(plain.default_format, FetchFormat::Text);

        let overridden = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(overridden.default_format, FetchFormat::Markdown);
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let outcome = run(json!({ "url": "http://localhost:8080" })).await;
        assert!(outcome.is_ok());

        // A rejected URL is reported through the ToolResult, not as an Err.
        let tool_result = outcome.expect("Should succeed");
        assert!(!tool_result.success);

        let mentions_https = tool_result.output.contains("HTTPS required");
        let mentions_block = tool_result.output.contains("blocked");
        assert!(mentions_https || mentions_block);
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let outcome = run(json!({})).await;

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let outcome = run(json!({ "url": "not-a-valid-url" })).await;
        assert!(outcome.is_ok());

        let tool_result = outcome.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        // Compile-only check: the validator field is private, so there is nothing to assert.
        let permissive = UrlValidator::new().with_allow_http();
        let _tool = LinkFetchTool::new().with_validator(permissive);
    }

    #[test]
    fn test_format_name() {
        let expectations = [
            (FetchFormat::Text, "text"),
            (FetchFormat::Markdown, "markdown"),
        ];

        for (format, expected) in expectations {
            assert_eq!(format_name(format), expected);
        }
    }
}