Skip to main content

agent_sdk/web/
fetch.rs

1//! Link fetch tool implementation.
2
3use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Largest response body, in bytes, that a fetch will accept (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Overall request timeout configured on the HTTP client built by `LinkFetchTool::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
16
/// Output format requested for fetched page content.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FetchFormat {
    /// Plain text output (HTML tags removed).
    #[default]
    Text,
    /// Markdown-formatted output.
    Markdown,
}

impl FetchFormat {
    /// Parse a caller-supplied format name, ignoring letter case.
    ///
    /// Recognises `"text"`, `"markdown"` and the `"md"` alias. Any other
    /// string, including one with surrounding whitespace, yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        if normalized == "text" {
            Some(Self::Text)
        } else if matches!(normalized.as_str(), "markdown" | "md") {
            Some(Self::Markdown)
        } else {
            None
        }
    }
}
36
37/// Link fetch tool for securely retrieving web page content.
38///
39/// This tool fetches web pages and converts them to text or markdown format.
40/// It includes SSRF protection to prevent access to internal resources.
41///
42/// # Example
43///
44/// ```ignore
45/// use agent_sdk::web::LinkFetchTool;
46///
47/// let tool = LinkFetchTool::new();
48///
49/// // Register with agent
50/// tools.register(tool);
51/// ```
52pub struct LinkFetchTool {
53    client: reqwest::Client,
54    validator: UrlValidator,
55    default_format: FetchFormat,
56}
57
58impl Default for LinkFetchTool {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl LinkFetchTool {
65    /// Create a new link fetch tool with default settings.
66    ///
67    /// # Panics
68    ///
69    /// Panics if the HTTP client cannot be built (should never happen with default settings).
70    #[must_use]
71    pub fn new() -> Self {
72        let client = reqwest::Client::builder()
73            .timeout(DEFAULT_TIMEOUT)
74            .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
75            .build()
76            .expect("Failed to build HTTP client");
77
78        Self {
79            client,
80            validator: UrlValidator::new(),
81            default_format: FetchFormat::Text,
82        }
83    }
84
85    /// Create with a custom URL validator.
86    #[must_use]
87    pub fn with_validator(mut self, validator: UrlValidator) -> Self {
88        self.validator = validator;
89        self
90    }
91
92    /// Create with a custom HTTP client.
93    #[must_use]
94    pub fn with_client(mut self, client: reqwest::Client) -> Self {
95        self.client = client;
96        self
97    }
98
99    /// Set the default output format.
100    #[must_use]
101    pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
102        self.default_format = format;
103        self
104    }
105
106    /// Fetch a URL and convert to the specified format.
107    async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
108        // Validate URL before fetching
109        let url = self.validator.validate(url_str)?;
110
111        // Build request with redirect policy
112        let response = self
113            .client
114            .get(url.as_str())
115            .send()
116            .await
117            .context("Failed to fetch URL")?;
118
119        // Check status
120        if !response.status().is_success() {
121            bail!("HTTP error: {}", response.status());
122        }
123
124        // Check content length if available
125        if let Some(len) = response.content_length()
126            && len > MAX_CONTENT_SIZE as u64
127        {
128            bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
129        }
130
131        // Get content type to determine processing
132        let content_type = response
133            .headers()
134            .get(reqwest::header::CONTENT_TYPE)
135            .and_then(|v| v.to_str().ok())
136            .unwrap_or("text/html")
137            .to_string();
138
139        // Read body with size limit
140        let bytes = response
141            .bytes()
142            .await
143            .context("Failed to read response body")?;
144
145        if bytes.len() > MAX_CONTENT_SIZE {
146            bail!(
147                "Content too large: {} bytes (max {} bytes)",
148                bytes.len(),
149                MAX_CONTENT_SIZE
150            );
151        }
152
153        // Convert to string
154        let html = String::from_utf8_lossy(&bytes);
155
156        // Process based on content type and format
157        if content_type.contains("text/html") || content_type.contains("application/xhtml") {
158            Ok(convert_html(&html, format))
159        } else if content_type.contains("text/plain") {
160            Ok(html.into_owned())
161        } else {
162            // For other content types, just return as-is
163            Ok(html.into_owned())
164        }
165    }
166}
167
168/// Convert HTML to the specified format.
169fn convert_html(html: &str, format: FetchFormat) -> String {
170    let result = match format {
171        FetchFormat::Text => {
172            // Use html2text with default width
173            html2text::from_read(html.as_bytes(), 80)
174        }
175        FetchFormat::Markdown => {
176            // Use html2text with markdown-friendly settings
177            html2text::from_read(html.as_bytes(), 80)
178        }
179    };
180    result.unwrap_or_else(|_| html.to_string())
181}
182
183impl<Ctx> Tool<Ctx> for LinkFetchTool
184where
185    Ctx: Send + Sync + 'static,
186{
187    type Name = PrimitiveToolName;
188
189    fn name(&self) -> PrimitiveToolName {
190        PrimitiveToolName::LinkFetch
191    }
192
193    fn display_name(&self) -> &'static str {
194        "Fetch URL"
195    }
196
197    fn description(&self) -> &'static str {
198        "Fetch and read web page content. Returns the page content as text or markdown. \
199         Includes SSRF protection to prevent access to internal resources."
200    }
201
202    fn input_schema(&self) -> Value {
203        json!({
204            "type": "object",
205            "properties": {
206                "url": {
207                    "type": "string",
208                    "description": "The URL to fetch (must be HTTPS)"
209                },
210                "format": {
211                    "type": "string",
212                    "enum": ["text", "markdown"],
213                    "description": "Output format (default: text)"
214                }
215            },
216            "required": ["url"]
217        })
218    }
219
220    fn tier(&self) -> ToolTier {
221        // Link fetch is read-only, so Observe tier
222        ToolTier::Observe
223    }
224
225    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
226        let url = input
227            .get("url")
228            .and_then(Value::as_str)
229            .context("Missing 'url' parameter")?;
230
231        let format = input
232            .get("format")
233            .and_then(Value::as_str)
234            .and_then(FetchFormat::from_str)
235            .unwrap_or(self.default_format);
236
237        match self.fetch_url(url, format).await {
238            Ok(content) => Ok(ToolResult {
239                success: true,
240                output: content,
241                data: Some(json!({ "url": url, "format": format_name(format) })),
242                duration_ms: None,
243            }),
244            Err(e) => Ok(ToolResult {
245                success: false,
246                output: format!("Failed to fetch URL: {e}"),
247                data: Some(json!({ "url": url, "error": e.to_string() })),
248                duration_ms: None,
249            }),
250        }
251    }
252}
253
254/// Get the format name for JSON output.
255const fn format_name(format: FetchFormat) -> &'static str {
256    match format {
257        FetchFormat::Text => "text",
258        FetchFormat::Markdown => "markdown",
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a freshly built tool against `input` with a unit context.
    async fn run(input: Value) -> Result<ToolResult> {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        Tool::<()>::execute(&tool, &ctx, input).await
    }

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::LinkFetch);
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
        assert!(Tool::<()>::description(&tool).contains("Fetch"));
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&LinkFetchTool::new());
        assert_eq!(schema["type"], "object");

        let properties = &schema["properties"];
        assert!(properties["url"].is_object());
        assert!(properties["format"].is_object());

        let url_is_required = schema["required"]
            .as_array()
            .is_some_and(|required| required.iter().any(|v| v == "url"));
        assert!(url_is_required);
    }

    #[test]
    fn test_format_from_str() {
        let cases = [
            ("text", Some(FetchFormat::Text)),
            ("TEXT", Some(FetchFormat::Text)),
            ("markdown", Some(FetchFormat::Markdown)),
            ("md", Some(FetchFormat::Markdown)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(FetchFormat::from_str(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_convert_html_text() {
        let rendered = convert_html(
            "<html><body><h1>Title</h1><p>Paragraph</p></body></html>",
            FetchFormat::Text,
        );
        assert!(rendered.contains("Title"));
        assert!(rendered.contains("Paragraph"));
    }

    #[test]
    fn test_default_format() {
        assert_eq!(LinkFetchTool::new().default_format, FetchFormat::Text);

        let markdown_tool = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(markdown_tool.default_format, FetchFormat::Markdown);
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let outcome = run(json!({ "url": "http://localhost:8080" }))
            .await
            .expect("Should succeed");

        assert!(!outcome.success);
        let text = &outcome.output;
        assert!(text.contains("HTTPS required") || text.contains("blocked"));
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let err = run(json!({}))
            .await
            .expect_err("a call without 'url' must fail");
        assert!(err.to_string().contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let outcome = run(json!({ "url": "not-a-valid-url" }))
            .await
            .expect("Should succeed");

        assert!(!outcome.success);
        assert!(outcome.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        // Compile-level check only: the validator field is private.
        let _tool = LinkFetchTool::new().with_validator(UrlValidator::new().with_allow_http());
    }

    #[test]
    fn test_format_name() {
        assert_eq!(format_name(FetchFormat::Markdown), "markdown");
        assert_eq!(format_name(FetchFormat::Text), "text");
    }
}