1use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Maximum accepted response body size in bytes (1 MiB).
const MAX_CONTENT_SIZE: usize = 1024 * 1024;

/// Request timeout configured on the HTTP client built by `LinkFetchTool::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
16
/// Output format requested for fetched page content.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FetchFormat {
    /// Plain-text output; this is the default variant.
    #[default]
    Text,
    /// Markdown output.
    Markdown,
}

impl FetchFormat {
    /// Parses a format name, ignoring case.
    ///
    /// Recognises `"text"`, `"markdown"` and the alias `"md"`. Any other input,
    /// including the empty string or names padded with whitespace, yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Every accepted spelling, already lower-cased, paired with its variant.
        const ALIASES: [(&str, FetchFormat); 3] = [
            ("text", FetchFormat::Text),
            ("markdown", FetchFormat::Markdown),
            ("md", FetchFormat::Markdown),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| wanted == *alias)
            .map(|&(_, format)| format)
    }
}
36
/// Tool that fetches a URL over HTTP and returns the page content as text.
///
/// The URL is passed through the configured [`UrlValidator`] before any request
/// is sent. HTML responses are converted with `convert_html`; other content is
/// returned as (lossily decoded) UTF-8 text.
pub struct LinkFetchTool {
    /// HTTP client used to issue requests.
    client: reqwest::Client,
    /// Validator applied to the URL before it is fetched.
    validator: UrlValidator,
    /// Format used when the tool input does not specify a recognised `format`.
    default_format: FetchFormat,
}
57
impl Default for LinkFetchTool {
    /// Equivalent to [`LinkFetchTool::new`].
    fn default() -> Self {
        Self::new()
    }
}
63
64impl LinkFetchTool {
65 #[must_use]
71 pub fn new() -> Self {
72 let client = reqwest::Client::builder()
73 .timeout(DEFAULT_TIMEOUT)
74 .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
75 .build()
76 .expect("Failed to build HTTP client");
77
78 Self {
79 client,
80 validator: UrlValidator::new(),
81 default_format: FetchFormat::Text,
82 }
83 }
84
85 #[must_use]
87 pub fn with_validator(mut self, validator: UrlValidator) -> Self {
88 self.validator = validator;
89 self
90 }
91
92 #[must_use]
94 pub fn with_client(mut self, client: reqwest::Client) -> Self {
95 self.client = client;
96 self
97 }
98
99 #[must_use]
101 pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
102 self.default_format = format;
103 self
104 }
105
106 async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
108 let url = self.validator.validate(url_str)?;
110
111 let response = self
113 .client
114 .get(url.as_str())
115 .send()
116 .await
117 .context("Failed to fetch URL")?;
118
119 if !response.status().is_success() {
121 bail!("HTTP error: {}", response.status());
122 }
123
124 if let Some(len) = response.content_length()
126 && len > MAX_CONTENT_SIZE as u64
127 {
128 bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
129 }
130
131 let content_type = response
133 .headers()
134 .get(reqwest::header::CONTENT_TYPE)
135 .and_then(|v| v.to_str().ok())
136 .unwrap_or("text/html")
137 .to_string();
138
139 let bytes = response
141 .bytes()
142 .await
143 .context("Failed to read response body")?;
144
145 if bytes.len() > MAX_CONTENT_SIZE {
146 bail!(
147 "Content too large: {} bytes (max {} bytes)",
148 bytes.len(),
149 MAX_CONTENT_SIZE
150 );
151 }
152
153 let html = String::from_utf8_lossy(&bytes);
155
156 if content_type.contains("text/html") || content_type.contains("application/xhtml") {
158 Ok(convert_html(&html, format))
159 } else if content_type.contains("text/plain") {
160 Ok(html.into_owned())
161 } else {
162 Ok(html.into_owned())
164 }
165 }
166}
167
168fn convert_html(html: &str, format: FetchFormat) -> String {
170 let result = match format {
171 FetchFormat::Text => {
172 html2text::from_read(html.as_bytes(), 80)
174 }
175 FetchFormat::Markdown => {
176 html2text::from_read(html.as_bytes(), 80)
178 }
179 };
180 result.unwrap_or_else(|_| html.to_string())
181}
182
183impl<Ctx> Tool<Ctx> for LinkFetchTool
184where
185 Ctx: Send + Sync + 'static,
186{
187 type Name = PrimitiveToolName;
188
189 fn name(&self) -> PrimitiveToolName {
190 PrimitiveToolName::LinkFetch
191 }
192
193 fn display_name(&self) -> &'static str {
194 "Fetch URL"
195 }
196
197 fn description(&self) -> &'static str {
198 "Fetch and read web page content. Returns the page content as text or markdown. \
199 Includes SSRF protection to prevent access to internal resources."
200 }
201
202 fn input_schema(&self) -> Value {
203 json!({
204 "type": "object",
205 "properties": {
206 "url": {
207 "type": "string",
208 "description": "The URL to fetch (must be HTTPS)"
209 },
210 "format": {
211 "type": "string",
212 "enum": ["text", "markdown"],
213 "description": "Output format (default: text)"
214 }
215 },
216 "required": ["url"]
217 })
218 }
219
220 fn tier(&self) -> ToolTier {
221 ToolTier::Observe
223 }
224
225 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
226 let url = input
227 .get("url")
228 .and_then(Value::as_str)
229 .context("Missing 'url' parameter")?;
230
231 let format = input
232 .get("format")
233 .and_then(Value::as_str)
234 .and_then(FetchFormat::from_str)
235 .unwrap_or(self.default_format);
236
237 match self.fetch_url(url, format).await {
238 Ok(content) => Ok(ToolResult {
239 success: true,
240 output: content,
241 data: Some(json!({ "url": url, "format": format_name(format) })),
242 duration_ms: None,
243 }),
244 Err(e) => Ok(ToolResult {
245 success: false,
246 output: format!("Failed to fetch URL: {e}"),
247 data: Some(json!({ "url": url, "error": e.to_string() })),
248 duration_ms: None,
249 }),
250 }
251 }
252}
253
254const fn format_name(format: FetchFormat) -> &'static str {
256 match format {
257 FetchFormat::Text => "text",
258 FetchFormat::Markdown => "markdown",
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes a freshly constructed tool against `input` with a unit context.
    async fn run_default_tool(input: Value) -> Result<ToolResult> {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        Tool::<()>::execute(&tool, &ctx, input).await
    }

    /// Name, description and tier are reported as expected.
    #[test]
    fn test_link_fetch_tool_metadata() {
        let fetcher = LinkFetchTool::new();

        assert_eq!(Tool::<()>::name(&fetcher), PrimitiveToolName::LinkFetch);
        assert!(Tool::<()>::description(&fetcher).contains("Fetch"));
        assert_eq!(Tool::<()>::tier(&fetcher), ToolTier::Observe);
    }

    /// The schema is an object with `url` and `format` properties, `url` required.
    #[test]
    fn test_link_fetch_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&LinkFetchTool::new());

        assert_eq!(schema["type"], "object");
        for property in ["url", "format"] {
            assert!(schema["properties"][property].is_object());
        }

        let required = schema["required"].as_array();
        assert!(required.is_some_and(|arr| arr.iter().any(|v| v == "url")));
    }

    /// Format names parse case-insensitively; unknown names are rejected.
    #[test]
    fn test_format_from_str() {
        let cases = [
            ("text", Some(FetchFormat::Text)),
            ("TEXT", Some(FetchFormat::Text)),
            ("markdown", Some(FetchFormat::Markdown)),
            ("md", Some(FetchFormat::Markdown)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(FetchFormat::from_str(input), expected);
        }
    }

    /// Visible text survives HTML conversion.
    #[test]
    fn test_convert_html_text() {
        let page = "<html><body><h1>Title</h1><p>Paragraph</p></body></html>";
        let rendered = convert_html(page, FetchFormat::Text);

        for expected in ["Title", "Paragraph"] {
            assert!(rendered.contains(expected));
        }
    }

    /// The default format is text and can be overridden via the builder.
    #[test]
    fn test_default_format() {
        assert_eq!(LinkFetchTool::new().default_format, FetchFormat::Text);

        let overridden = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(overridden.default_format, FetchFormat::Markdown);
    }

    /// A URL rejected by the validator yields an unsuccessful result, not an error.
    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let result = run_default_tool(json!({ "url": "http://localhost:8080" })).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(
            ["HTTPS required", "blocked"]
                .iter()
                .any(|needle| tool_result.output.contains(*needle))
        );
    }

    /// Omitting `url` is reported as an error mentioning the parameter.
    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let result = run_default_tool(json!({})).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("url"));
    }

    /// A malformed URL yields an unsuccessful result with an explanatory message.
    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let result = run_default_tool(json!({ "url": "not-a-valid-url" })).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    /// The builder accepts a custom validator.
    #[test]
    fn test_with_validator() {
        let permissive = UrlValidator::new().with_allow_http();
        let _tool = LinkFetchTool::new().with_validator(permissive);
    }

    /// Each format maps to its lower-case name.
    #[test]
    fn test_format_name() {
        let expectations = [
            (FetchFormat::Text, "text"),
            (FetchFormat::Markdown, "markdown"),
        ];

        for (format, name) in expectations {
            assert_eq!(format_name(format), name);
        }
    }
}