Skip to main content

agent_sdk/web/
fetch.rs

1//! Link fetch tool implementation.
2
3use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Upper bound on the number of response-body bytes buffered per fetch (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Timeout configured on the default HTTP client for each request it sends.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
16
/// Output format produced by the link fetch tool.
///
/// Plain text is the sole supported format. An earlier `Markdown` variant was
/// dropped because it yielded exactly the same bytes as `Text` (no markdown
/// conversion was ever implemented); keeping it would have advertised a
/// distinction to the model that did not exist.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FetchFormat {
    /// HTML reduced to plain text (markup removed).
    #[default]
    Text,
}
29
/// Link fetch tool for securely retrieving web page content.
///
/// This tool fetches web pages and returns their content as plain text: HTML
/// (and XHTML) responses are rendered to text, while other content types are
/// returned as (lossily) UTF-8-decoded text. No markdown output is produced.
/// It includes SSRF protection to prevent access to internal resources: the
/// initial URL and every redirect target are checked by the `UrlValidator`.
///
/// # Example
///
/// ```ignore
/// use agent_sdk::web::LinkFetchTool;
///
/// let tool = LinkFetchTool::new();
///
/// // Register with agent
/// tools.register(tool);
/// ```
pub struct LinkFetchTool {
    /// Optional caller-supplied HTTP client.
    ///
    /// When `None` (the default), a fresh client is built per request with the
    /// validated IP addresses pinned via [`reqwest::ClientBuilder::resolve_to_addrs`]
    /// so the connection targets exactly the addresses that passed SSRF
    /// validation (closing the DNS-rebinding window). When a custom client is
    /// supplied via [`LinkFetchTool::with_client`], it is used as-is and the
    /// caller is responsible for its redirect/SSRF policy.
    client: Option<reqwest::Client>,
    /// URL validator run on the initial URL and on every redirect target
    /// (SSRF checks); it also supplies the maximum number of redirects
    /// followed.
    validator: UrlValidator,
}
57
58impl Default for LinkFetchTool {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl LinkFetchTool {
65    /// Create a new link fetch tool with default settings.
66    ///
67    /// Does not build an HTTP client eagerly: the default client is constructed
68    /// per request (with the validated IPs pinned), so this constructor cannot
69    /// fail or panic.
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            client: None,
74            validator: UrlValidator::new(),
75        }
76    }
77
78    /// Create with a custom URL validator.
79    #[must_use]
80    pub fn with_validator(mut self, validator: UrlValidator) -> Self {
81        self.validator = validator;
82        self
83    }
84
85    /// Create with a custom HTTP client.
86    ///
87    /// The supplied client is used verbatim; per-request IP pinning is *not*
88    /// applied, so the caller takes responsibility for redirect and SSRF
89    /// policy on that client.
90    #[must_use]
91    pub fn with_client(mut self, client: reqwest::Client) -> Self {
92        self.client = Some(client);
93        self
94    }
95
96    /// Build the HTTP client for a single request.
97    ///
98    /// Returns the caller-supplied client if one was configured; otherwise
99    /// builds a default client with `host` pinned to the vetted `addrs` so the
100    /// connection cannot be rebound to a different (blocked) address after
101    /// validation.
102    fn build_client(
103        &self,
104        host: Option<&str>,
105        addrs: &[std::net::SocketAddr],
106    ) -> Result<reqwest::Client> {
107        if let Some(client) = &self.client {
108            return Ok(client.clone());
109        }
110
111        let mut builder = reqwest::Client::builder()
112            .redirect(reqwest::redirect::Policy::none())
113            .timeout(DEFAULT_TIMEOUT)
114            .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)");
115
116        if let Some(host) = host
117            && !addrs.is_empty()
118        {
119            builder = builder.resolve_to_addrs(host, addrs);
120        }
121
122        builder.build().context("Failed to build HTTP client")
123    }
124
125    /// Fetch a URL and convert it to plain text.
126    ///
127    /// Manually follows redirects, validating each target URL through the
128    /// SSRF validator (and pinning its resolved IPs) to prevent redirect-based
129    /// and DNS-rebinding SSRF attacks.
130    async fn fetch_url(&self, url_str: &str) -> Result<String> {
131        // Validate initial URL before fetching, capturing the vetted addresses.
132        let mut validated = self.validator.validate(url_str).await?;
133        let max_redirects = self.validator.max_redirects();
134
135        let client = self.build_client(validated.url.host_str(), &validated.addresses)?;
136        let mut response = client
137            .get(validated.url.as_str())
138            .send()
139            .await
140            .context("Failed to fetch URL")?;
141
142        // Manually follow redirects with validation
143        let mut redirects = 0;
144        while response.status().is_redirection() {
145            redirects += 1;
146            if redirects > max_redirects {
147                bail!("Too many redirects ({redirects} > {max_redirects})");
148            }
149
150            let location = response
151                .headers()
152                .get(reqwest::header::LOCATION)
153                .context("Redirect response missing Location header")?
154                .to_str()
155                .context("Invalid Location header")?;
156
157            // Resolve relative redirect URLs against the current URL
158            let redirect_url_str = validated
159                .url
160                .join(location)
161                .map_or_else(|_| location.to_string(), |u| u.to_string());
162
163            // Validate the redirect target through the same SSRF checks and
164            // pin its freshly-vetted addresses.
165            validated = self.validator.validate(&redirect_url_str).await?;
166
167            let client = self.build_client(validated.url.host_str(), &validated.addresses)?;
168            response = client
169                .get(validated.url.as_str())
170                .send()
171                .await
172                .context("Failed to follow redirect")?;
173        }
174
175        // Check status
176        if !response.status().is_success() {
177            bail!("HTTP error: {}", response.status());
178        }
179
180        // Reject early if the advertised length already exceeds the cap.
181        if let Some(len) = response.content_length()
182            && len > MAX_CONTENT_SIZE as u64
183        {
184            bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
185        }
186
187        // Get content type to determine processing
188        let content_type = response
189            .headers()
190            .get(reqwest::header::CONTENT_TYPE)
191            .and_then(|v| v.to_str().ok())
192            .unwrap_or("text/html")
193            .to_string();
194
195        // Stream the body, bailing as soon as the cumulative size exceeds the
196        // cap. This bounds peak memory at ~MAX_CONTENT_SIZE regardless of
197        // whether the server sends a Content-Length header (chunked/streaming
198        // responses would otherwise allow unbounded allocation).
199        let bytes = read_capped_body(&mut response, MAX_CONTENT_SIZE).await?;
200
201        // Convert to string
202        let html = String::from_utf8_lossy(&bytes);
203
204        // Process based on content type
205        if content_type.contains("text/html") || content_type.contains("application/xhtml") {
206            Ok(convert_html(&html))
207        } else if content_type.contains("text/plain") {
208            Ok(html.into_owned())
209        } else {
210            // For other content types, just return as-is
211            Ok(html.into_owned())
212        }
213    }
214}
215
216/// Convert HTML to plain text.
217fn convert_html(html: &str) -> String {
218    html2text::from_read(html.as_bytes(), 80).unwrap_or_else(|_| html.to_string())
219}
220
221/// Read a response body into memory, bailing as soon as the cumulative size
222/// exceeds `max`.
223///
224/// Streams via [`reqwest::Response::chunk`] so peak memory stays bounded at
225/// ~`max` even for chunked/streaming responses that carry no `Content-Length`.
226async fn read_capped_body(response: &mut reqwest::Response, max: usize) -> Result<Vec<u8>> {
227    let mut bytes: Vec<u8> = Vec::new();
228    while let Some(chunk) = response
229        .chunk()
230        .await
231        .context("Failed to read response body")?
232    {
233        if bytes.len() + chunk.len() > max {
234            bail!("Content too large: exceeds {max} bytes");
235        }
236        bytes.extend_from_slice(&chunk);
237    }
238    Ok(bytes)
239}
240
241impl<Ctx> Tool<Ctx> for LinkFetchTool
242where
243    Ctx: Send + Sync + 'static,
244{
245    type Name = PrimitiveToolName;
246
247    fn name(&self) -> PrimitiveToolName {
248        PrimitiveToolName::LinkFetch
249    }
250
251    fn display_name(&self) -> &'static str {
252        "Fetch URL"
253    }
254
255    fn description(&self) -> &'static str {
256        "Fetch and read web page content. Returns the page content as text or markdown. \
257         Includes SSRF protection to prevent access to internal resources."
258    }
259
260    fn input_schema(&self) -> Value {
261        json!({
262            "type": "object",
263            "properties": {
264                "url": {
265                    "type": "string",
266                    "description": "The URL to fetch (must be HTTPS)"
267                }
268            },
269            "required": ["url"]
270        })
271    }
272
273    fn tier(&self) -> ToolTier {
274        // Link fetch is read-only, so Observe tier
275        ToolTier::Observe
276    }
277
278    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
279        let url = input
280            .get("url")
281            .and_then(Value::as_str)
282            .context("Missing 'url' parameter")?;
283
284        match self.fetch_url(url).await {
285            Ok(content) => Ok(ToolResult {
286                success: true,
287                output: content,
288                data: Some(json!({ "url": url })),
289                documents: Vec::new(),
290                duration_ms: None,
291            }),
292            Err(e) => Ok(ToolResult {
293                success: false,
294                output: format!("Failed to fetch URL: {e}"),
295                data: Some(json!({ "url": url, "error": e.to_string() })),
296                documents: Vec::new(),
297                duration_ms: None,
298            }),
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::LinkFetch);
        assert!(Tool::<()>::description(&tool).contains("Fetch"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let tool = LinkFetchTool::new();

        let schema = Tool::<()>::input_schema(&tool);
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["url"].is_object());
        // The dead `format`/markdown option was removed from the schema.
        assert!(schema["properties"]["format"].is_null());
        assert!(
            schema["required"]
                .as_array()
                .is_some_and(|arr| arr.iter().any(|v| v == "url"))
        );
    }

    #[test]
    fn test_convert_html_text() {
        let html = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
        let result = convert_html(html);
        assert!(result.contains("Title"));
        assert!(result.contains("Paragraph"));
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({ "url": "http://localhost:8080" });

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(
            tool_result.output.contains("HTTPS required") || tool_result.output.contains("blocked")
        );
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({});

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({ "url": "not-a-valid-url" });

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        let validator = UrlValidator::new().with_allow_http();
        let _tool = LinkFetchTool::new().with_validator(validator);
        // Construction smoke test: the builder accepts a custom validator.
    }

    /// The default (per-request) client must NOT follow redirects on its own;
    /// `fetch_url` follows them manually so every hop is SSRF-validated.
    ///
    /// A loopback server answers with a 302 pointing at an unreachable port.
    /// If the client auto-followed it, `send()` would fail to connect, so
    /// observing the raw 302 proves automatic redirect-following is disabled.
    #[tokio::test]
    async fn test_redirects_disabled_in_client() -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let tool = LinkFetchTool::new();
        // Default tool uses the per-request client path (no custom client).
        assert!(tool.client.is_none());
        assert_eq!(tool.validator.max_redirects(), 3);

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server = tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let resp = "HTTP/1.1 302 Found\r\nLocation: http://127.0.0.1:1/elsewhere\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
                let _ = sock.write_all(resp.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        let client = tool.build_client(None, &[])?;
        let response = client.get(format!("http://{addr}/redirect")).send().await?;
        server.abort();

        assert_eq!(response.status(), reqwest::StatusCode::FOUND);
        Ok(())
    }

    #[tokio::test]
    async fn test_redirect_to_private_ip_blocked() {
        // Simulate: a redirect target pointing to a private IP should be blocked
        // by the validator during manual redirect following.
        let validator = UrlValidator::new().with_allow_http();

        // Direct access to private IPs should be blocked
        let result = validator
            .validate("http://169.254.169.254/latest/meta-data/")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));

        // Direct access to 10.x should be blocked
        let result = validator.validate("http://10.0.0.1/internal").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_redirect_to_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        // Redirect target pointing to localhost should be blocked
        let result = validator.validate("http://127.0.0.1/admin").await;
        assert!(result.is_err());
    }

    /// Regression test for the body-size cap (findings 4 & 14). A server that
    /// sends a body larger than the cap with NO `Content-Length` header
    /// (so the header pre-check cannot apply) must be rejected while streaming,
    /// before the whole body is buffered. Exercised against a local loopback
    /// server so the test is deterministic and needs no network.
    #[tokio::test]
    async fn test_read_capped_body_rejects_oversized_stream() -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server = tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let header =
                    "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n";
                let _ = sock.write_all(header.as_bytes()).await;
                let chunk = vec![b'a'; 64 * 1024];
                // Write well past the 1 MiB cap used below.
                for _ in 0..40 {
                    if sock.write_all(&chunk).await.is_err() {
                        break;
                    }
                }
                let _ = sock.shutdown().await;
            }
        });

        let client = reqwest::Client::builder().build()?;
        let mut response = client.get(format!("http://{addr}/big")).send().await?;
        let result = read_capped_body(&mut response, 1024 * 1024).await;
        server.abort();

        assert!(result.is_err(), "oversized streamed body must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Content too large"),
            "expected size-cap error, got: {msg}"
        );
        Ok(())
    }

    /// A body within the cap must be returned in full.
    #[tokio::test]
    async fn test_read_capped_body_accepts_small_stream() -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server = tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let body = "hello world";
                let resp = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                    body.len()
                );
                let _ = sock.write_all(resp.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        let client = reqwest::Client::builder().build()?;
        let mut response = client.get(format!("http://{addr}/small")).send().await?;
        let bytes = read_capped_body(&mut response, 1024 * 1024).await?;
        server.abort();

        assert_eq!(String::from_utf8_lossy(&bytes), "hello world");
        Ok(())
    }
}