1use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Largest response body, in bytes, that the fetcher will accept (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Overall request timeout configured on the HTTP client built by `LinkFetchTool::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
16
/// Output representation that a caller can request for fetched content.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FetchFormat {
    /// Plain-text output; this is the `Default` variant.
    #[default]
    Text,
    /// Markdown output.
    Markdown,
}

impl FetchFormat {
    /// Parses a format name without regard to letter case.
    ///
    /// Recognises `"text"`, `"markdown"` and the alias `"md"`. Anything else,
    /// including the empty string or a name with surrounding whitespace,
    /// yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Every accepted spelling, already lowercased, paired with its variant.
        const KNOWN_NAMES: [(&str, FetchFormat); 3] = [
            ("text", FetchFormat::Text),
            ("markdown", FetchFormat::Markdown),
            ("md", FetchFormat::Markdown),
        ];

        let needle = s.to_lowercase();
        KNOWN_NAMES
            .iter()
            .find(|(name, _)| *name == needle.as_str())
            .map(|&(_, variant)| variant)
    }
}
36
37pub struct LinkFetchTool {
53 client: reqwest::Client,
54 validator: UrlValidator,
55 default_format: FetchFormat,
56}
57
58impl Default for LinkFetchTool {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl LinkFetchTool {
65 #[must_use]
71 pub fn new() -> Self {
72 let client = reqwest::Client::builder()
73 .redirect(reqwest::redirect::Policy::none())
74 .timeout(DEFAULT_TIMEOUT)
75 .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
76 .build()
77 .expect("Failed to build HTTP client");
78
79 Self {
80 client,
81 validator: UrlValidator::new(),
82 default_format: FetchFormat::Text,
83 }
84 }
85
86 #[must_use]
88 pub fn with_validator(mut self, validator: UrlValidator) -> Self {
89 self.validator = validator;
90 self
91 }
92
93 #[must_use]
95 pub fn with_client(mut self, client: reqwest::Client) -> Self {
96 self.client = client;
97 self
98 }
99
100 #[must_use]
102 pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
103 self.default_format = format;
104 self
105 }
106
107 async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
112 let mut url = self.validator.validate(url_str)?;
114 let max_redirects = self.validator.max_redirects();
115
116 let mut response = self
117 .client
118 .get(url.as_str())
119 .send()
120 .await
121 .context("Failed to fetch URL")?;
122
123 let mut redirects = 0;
125 while response.status().is_redirection() {
126 redirects += 1;
127 if redirects > max_redirects {
128 bail!("Too many redirects ({redirects} > {max_redirects})");
129 }
130
131 let location = response
132 .headers()
133 .get(reqwest::header::LOCATION)
134 .context("Redirect response missing Location header")?
135 .to_str()
136 .context("Invalid Location header")?;
137
138 let redirect_url_str = url
140 .join(location)
141 .map_or_else(|_| location.to_string(), |u| u.to_string());
142
143 url = self.validator.validate(&redirect_url_str)?;
145
146 response = self
147 .client
148 .get(url.as_str())
149 .send()
150 .await
151 .context("Failed to follow redirect")?;
152 }
153
154 if !response.status().is_success() {
156 bail!("HTTP error: {}", response.status());
157 }
158
159 if let Some(len) = response.content_length()
161 && len > MAX_CONTENT_SIZE as u64
162 {
163 bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
164 }
165
166 let content_type = response
168 .headers()
169 .get(reqwest::header::CONTENT_TYPE)
170 .and_then(|v| v.to_str().ok())
171 .unwrap_or("text/html")
172 .to_string();
173
174 let bytes = response
176 .bytes()
177 .await
178 .context("Failed to read response body")?;
179
180 if bytes.len() > MAX_CONTENT_SIZE {
181 bail!(
182 "Content too large: {} bytes (max {} bytes)",
183 bytes.len(),
184 MAX_CONTENT_SIZE
185 );
186 }
187
188 let html = String::from_utf8_lossy(&bytes);
190
191 if content_type.contains("text/html") || content_type.contains("application/xhtml") {
193 Ok(convert_html(&html, format))
194 } else if content_type.contains("text/plain") {
195 Ok(html.into_owned())
196 } else {
197 Ok(html.into_owned())
199 }
200 }
201}
202
203fn convert_html(html: &str, format: FetchFormat) -> String {
205 let result = match format {
206 FetchFormat::Text => {
207 html2text::from_read(html.as_bytes(), 80)
209 }
210 FetchFormat::Markdown => {
211 html2text::from_read(html.as_bytes(), 80)
213 }
214 };
215 result.unwrap_or_else(|_| html.to_string())
216}
217
218impl<Ctx> Tool<Ctx> for LinkFetchTool
219where
220 Ctx: Send + Sync + 'static,
221{
222 type Name = PrimitiveToolName;
223
224 fn name(&self) -> PrimitiveToolName {
225 PrimitiveToolName::LinkFetch
226 }
227
228 fn display_name(&self) -> &'static str {
229 "Fetch URL"
230 }
231
232 fn description(&self) -> &'static str {
233 "Fetch and read web page content. Returns the page content as text or markdown. \
234 Includes SSRF protection to prevent access to internal resources."
235 }
236
237 fn input_schema(&self) -> Value {
238 json!({
239 "type": "object",
240 "properties": {
241 "url": {
242 "type": "string",
243 "description": "The URL to fetch (must be HTTPS)"
244 },
245 "format": {
246 "type": "string",
247 "enum": ["text", "markdown"],
248 "description": "Output format (default: text)"
249 }
250 },
251 "required": ["url"]
252 })
253 }
254
255 fn tier(&self) -> ToolTier {
256 ToolTier::Observe
258 }
259
260 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
261 let url = input
262 .get("url")
263 .and_then(Value::as_str)
264 .context("Missing 'url' parameter")?;
265
266 let format = input
267 .get("format")
268 .and_then(Value::as_str)
269 .and_then(FetchFormat::from_str)
270 .unwrap_or(self.default_format);
271
272 match self.fetch_url(url, format).await {
273 Ok(content) => Ok(ToolResult {
274 success: true,
275 output: content,
276 data: Some(json!({ "url": url, "format": format_name(format) })),
277 documents: Vec::new(),
278 duration_ms: None,
279 }),
280 Err(e) => Ok(ToolResult {
281 success: false,
282 output: format!("Failed to fetch URL: {e}"),
283 data: Some(json!({ "url": url, "error": e.to_string() })),
284 documents: Vec::new(),
285 duration_ms: None,
286 }),
287 }
288 }
289}
290
291const fn format_name(format: FetchFormat) -> &'static str {
293 match format {
294 FetchFormat::Text => "text",
295 FetchFormat::Markdown => "markdown",
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes a freshly constructed tool against `input` with a unit context.
    async fn run_tool(input: Value) -> Result<ToolResult> {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        Tool::<()>::execute(&tool, &ctx, input).await
    }

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();

        let name = Tool::<()>::name(&tool);
        let description = Tool::<()>::description(&tool);
        let tier = Tool::<()>::tier(&tool);

        assert_eq!(name, PrimitiveToolName::LinkFetch);
        assert!(description.contains("Fetch"));
        assert_eq!(tier, ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&LinkFetchTool::new());

        assert_eq!(schema["type"], "object");
        for property in ["url", "format"] {
            assert!(schema["properties"][property].is_object());
        }

        let url_is_required = schema["required"]
            .as_array()
            .is_some_and(|required| required.iter().any(|entry| entry == "url"));
        assert!(url_is_required);
    }

    #[test]
    fn test_format_from_str() {
        let cases = [
            ("text", Some(FetchFormat::Text)),
            ("TEXT", Some(FetchFormat::Text)),
            ("markdown", Some(FetchFormat::Markdown)),
            ("md", Some(FetchFormat::Markdown)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(FetchFormat::from_str(input), expected);
        }
    }

    #[test]
    fn test_convert_html_text() {
        let rendered = convert_html(
            "<html><body><h1>Title</h1><p>Paragraph</p></body></html>",
            FetchFormat::Text,
        );

        for expected in ["Title", "Paragraph"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_default_format() {
        let untouched = LinkFetchTool::new();
        assert_eq!(untouched.default_format, FetchFormat::Text);

        let overridden = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(overridden.default_format, FetchFormat::Markdown);
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let result = run_tool(json!({ "url": "http://localhost:8080" })).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);

        // Which message appears depends on which validator rule rejects the URL first.
        let output = &tool_result.output;
        assert!(output.contains("HTTPS required") || output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let result = run_tool(json!({})).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let result = run_tool(json!({ "url": "not-a-valid-url" })).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        // Only checks that the builder accepts a customised validator.
        let _tool = LinkFetchTool::new().with_validator(UrlValidator::new().with_allow_http());
    }

    #[test]
    fn test_format_name() {
        let cases = [
            (FetchFormat::Text, "text"),
            (FetchFormat::Markdown, "markdown"),
        ];

        for (format, expected) in cases {
            assert_eq!(format_name(format), expected);
        }
    }

    #[test]
    fn test_redirects_disabled_in_client() {
        // Asserts the redirect limit reported by the default validator.
        let tool = LinkFetchTool::new();
        assert_eq!(tool.validator.max_redirects(), 3);
    }

    #[tokio::test]
    async fn test_redirect_to_private_ip_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        let metadata = validator.validate("http://169.254.169.254/latest/meta-data/");
        assert!(metadata.is_err());
        assert!(metadata.unwrap_err().to_string().contains("blocked"));

        let private_range = validator.validate("http://10.0.0.1/internal");
        assert!(private_range.is_err());
    }

    #[tokio::test]
    async fn test_redirect_to_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();

        let loopback = validator.validate("http://127.0.0.1/admin");
        assert!(loopback.is_err());
    }
}