1use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Largest response body, in bytes, that the fetch tool will accept (1 MiB).
///
/// Checked against a declared `Content-Length` and again while the body is streamed.
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Per-request timeout for HTTP clients the fetch tool builds itself (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
16
/// Output format for fetched content.
///
/// `Text` is the only variant, and therefore also the default.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum FetchFormat {
    /// Plain-text rendering of the fetched document.
    #[default]
    Text,
}
29
30pub struct LinkFetchTool {
46 client: Option<reqwest::Client>,
55 validator: UrlValidator,
56}
57
58impl Default for LinkFetchTool {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl LinkFetchTool {
65 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 client: None,
74 validator: UrlValidator::new(),
75 }
76 }
77
78 #[must_use]
80 pub fn with_validator(mut self, validator: UrlValidator) -> Self {
81 self.validator = validator;
82 self
83 }
84
85 #[must_use]
91 pub fn with_client(mut self, client: reqwest::Client) -> Self {
92 self.client = Some(client);
93 self
94 }
95
96 fn build_client(
103 &self,
104 host: Option<&str>,
105 addrs: &[std::net::SocketAddr],
106 ) -> Result<reqwest::Client> {
107 if let Some(client) = &self.client {
108 return Ok(client.clone());
109 }
110
111 let mut builder = reqwest::Client::builder()
112 .redirect(reqwest::redirect::Policy::none())
113 .timeout(DEFAULT_TIMEOUT)
114 .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)");
115
116 if let Some(host) = host
117 && !addrs.is_empty()
118 {
119 builder = builder.resolve_to_addrs(host, addrs);
120 }
121
122 builder.build().context("Failed to build HTTP client")
123 }
124
125 async fn fetch_url(&self, url_str: &str) -> Result<String> {
131 let mut validated = self.validator.validate(url_str).await?;
133 let max_redirects = self.validator.max_redirects();
134
135 let client = self.build_client(validated.url.host_str(), &validated.addresses)?;
136 let mut response = client
137 .get(validated.url.as_str())
138 .send()
139 .await
140 .context("Failed to fetch URL")?;
141
142 let mut redirects = 0;
144 while response.status().is_redirection() {
145 redirects += 1;
146 if redirects > max_redirects {
147 bail!("Too many redirects ({redirects} > {max_redirects})");
148 }
149
150 let location = response
151 .headers()
152 .get(reqwest::header::LOCATION)
153 .context("Redirect response missing Location header")?
154 .to_str()
155 .context("Invalid Location header")?;
156
157 let redirect_url_str = validated
159 .url
160 .join(location)
161 .map_or_else(|_| location.to_string(), |u| u.to_string());
162
163 validated = self.validator.validate(&redirect_url_str).await?;
166
167 let client = self.build_client(validated.url.host_str(), &validated.addresses)?;
168 response = client
169 .get(validated.url.as_str())
170 .send()
171 .await
172 .context("Failed to follow redirect")?;
173 }
174
175 if !response.status().is_success() {
177 bail!("HTTP error: {}", response.status());
178 }
179
180 if let Some(len) = response.content_length()
182 && len > MAX_CONTENT_SIZE as u64
183 {
184 bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
185 }
186
187 let content_type = response
189 .headers()
190 .get(reqwest::header::CONTENT_TYPE)
191 .and_then(|v| v.to_str().ok())
192 .unwrap_or("text/html")
193 .to_string();
194
195 let bytes = read_capped_body(&mut response, MAX_CONTENT_SIZE).await?;
200
201 let html = String::from_utf8_lossy(&bytes);
203
204 if content_type.contains("text/html") || content_type.contains("application/xhtml") {
206 Ok(convert_html(&html))
207 } else if content_type.contains("text/plain") {
208 Ok(html.into_owned())
209 } else {
210 Ok(html.into_owned())
212 }
213 }
214}
215
216fn convert_html(html: &str) -> String {
218 html2text::from_read(html.as_bytes(), 80).unwrap_or_else(|_| html.to_string())
219}
220
221async fn read_capped_body(response: &mut reqwest::Response, max: usize) -> Result<Vec<u8>> {
227 let mut bytes: Vec<u8> = Vec::new();
228 while let Some(chunk) = response
229 .chunk()
230 .await
231 .context("Failed to read response body")?
232 {
233 if bytes.len() + chunk.len() > max {
234 bail!("Content too large: exceeds {max} bytes");
235 }
236 bytes.extend_from_slice(&chunk);
237 }
238 Ok(bytes)
239}
240
241impl<Ctx> Tool<Ctx> for LinkFetchTool
242where
243 Ctx: Send + Sync + 'static,
244{
245 type Name = PrimitiveToolName;
246
247 fn name(&self) -> PrimitiveToolName {
248 PrimitiveToolName::LinkFetch
249 }
250
251 fn display_name(&self) -> &'static str {
252 "Fetch URL"
253 }
254
255 fn description(&self) -> &'static str {
256 "Fetch and read web page content. Returns the page content as text or markdown. \
257 Includes SSRF protection to prevent access to internal resources."
258 }
259
260 fn input_schema(&self) -> Value {
261 json!({
262 "type": "object",
263 "properties": {
264 "url": {
265 "type": "string",
266 "description": "The URL to fetch (must be HTTPS)"
267 }
268 },
269 "required": ["url"]
270 })
271 }
272
273 fn tier(&self) -> ToolTier {
274 ToolTier::Observe
276 }
277
278 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
279 let url = input
280 .get("url")
281 .and_then(Value::as_str)
282 .context("Missing 'url' parameter")?;
283
284 match self.fetch_url(url).await {
285 Ok(content) => Ok(ToolResult {
286 success: true,
287 output: content,
288 data: Some(json!({ "url": url })),
289 documents: Vec::new(),
290 duration_ms: None,
291 }),
292 Err(e) => Ok(ToolResult {
293 success: false,
294 output: format!("Failed to fetch URL: {e}"),
295 data: Some(json!({ "url": url, "error": e.to_string() })),
296 documents: Vec::new(),
297 duration_ms: None,
298 }),
299 }
300 }
301}
302
#[cfg(test)]
mod tests {
    //! Unit tests for `LinkFetchTool`. Tests that need an HTTP peer start a one-shot
    //! server on 127.0.0.1.

    use super::*;

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::LinkFetch);
        assert!(Tool::<()>::description(&tool).contains("Fetch"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let tool = LinkFetchTool::new();

        let schema = Tool::<()>::input_schema(&tool);
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["url"].is_object());
        assert!(schema["properties"]["format"].is_null());
        assert!(
            schema["required"]
                .as_array()
                .is_some_and(|arr| arr.iter().any(|v| v == "url"))
        );
    }

    #[test]
    fn test_convert_html_text() {
        let html = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
        let result = convert_html(html);
        assert!(result.contains("Title"));
        assert!(result.contains("Paragraph"));
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({ "url": "http://localhost:8080" });

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(
            tool_result.output.contains("HTTPS required") || tool_result.output.contains("blocked")
        );
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({});

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        let input = json!({ "url": "not-a-valid-url" });

        let result = Tool::<()>::execute(&tool, &ctx, input).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[tokio::test]
    async fn test_with_validator() {
        // With HTTP allowed, a plain-http URL must not fail on the scheme check. A
        // loopback address is still refused (see test_redirect_to_localhost_blocked).
        let validator = UrlValidator::new().with_allow_http();
        let tool = LinkFetchTool::new().with_validator(validator);
        let ctx = ToolContext::new(());
        let input = json!({ "url": "http://127.0.0.1:8080/" });

        let tool_result = Tool::<()>::execute(&tool, &ctx, input)
            .await
            .expect("fetch failures are reported in the ToolResult, not as Err");
        assert!(!tool_result.success);
        assert!(
            !tool_result.output.contains("HTTPS required"),
            "custom validator should allow http, got: {}",
            tool_result.output
        );
    }

    #[tokio::test]
    async fn test_redirects_disabled_in_client() -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let tool = LinkFetchTool::new();
        assert_eq!(tool.validator.max_redirects(), 3);

        // A client from `build_client` must return the 302 itself rather than follow
        // it, so that `fetch_url` can validate the redirect target.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server = tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let resp = "HTTP/1.1 302 Found\r\nLocation: /elsewhere\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
                let _ = sock.write_all(resp.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        let client = tool.build_client(None, &[])?;
        let response = client.get(format!("http://{addr}/start")).send().await?;
        server.abort();

        assert_eq!(response.status(), reqwest::StatusCode::FOUND);
        assert_eq!(
            response
                .headers()
                .get(reqwest::header::LOCATION)
                .and_then(|v| v.to_str().ok()),
            Some("/elsewhere")
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_redirect_to_private_ip_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        let result = validator
            .validate("http://169.254.169.254/latest/meta-data/")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));

        let result = validator.validate("http://10.0.0.1/internal").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_redirect_to_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        let result = validator.validate("http://127.0.0.1/admin").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_read_capped_body_rejects_oversized_stream() -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        // No Content-Length is sent, so only the streaming cap can reject this body
        // (40 x 64 KiB = 2.5 MiB).
        let server = tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let header =
                    "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n";
                let _ = sock.write_all(header.as_bytes()).await;
                let chunk = vec![b'a'; 64 * 1024];
                for _ in 0..40 {
                    if sock.write_all(&chunk).await.is_err() {
                        break;
                    }
                }
                let _ = sock.shutdown().await;
            }
        });

        let client = reqwest::Client::builder().build()?;
        let mut response = client.get(format!("http://{addr}/big")).send().await?;
        let result = read_capped_body(&mut response, 1024 * 1024).await;
        server.abort();

        assert!(result.is_err(), "oversized streamed body must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Content too large"),
            "expected size-cap error, got: {msg}"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_read_capped_body_accepts_small_stream() -> Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let server = tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut buf = [0u8; 1024];
                let _ = sock.read(&mut buf).await;
                let body = "hello world";
                let resp = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                    body.len()
                );
                let _ = sock.write_all(resp.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        let client = reqwest::Client::builder().build()?;
        let mut response = client.get(format!("http://{addr}/small")).send().await?;
        let bytes = read_capped_body(&mut response, 1024 * 1024).await?;
        server.abort();

        assert_eq!(String::from_utf8_lossy(&bytes), "hello world");
        Ok(())
    }
}