1use crate::tools::{Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use serde_json::{Value, json};
8use std::time::Duration;
9
10use super::security::UrlValidator;
11
/// Largest response body, in bytes, that the tool accepts (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Total request timeout configured on the HTTP client built by `LinkFetchTool::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
17
/// Output format requested for fetched HTML content.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FetchFormat {
    /// Text output; this is the default when no format is given.
    #[default]
    Text,
    /// Markdown output.
    ///
    /// NOTE(review): `convert_html` currently renders this exactly like `Text`.
    Markdown,
}
27
28impl FetchFormat {
29 fn from_str(s: &str) -> Option<Self> {
30 match s.to_lowercase().as_str() {
31 "text" => Some(Self::Text),
32 "markdown" | "md" => Some(Self::Markdown),
33 _ => None,
34 }
35 }
36}
37
38pub struct LinkFetchTool {
54 client: reqwest::Client,
55 validator: UrlValidator,
56 default_format: FetchFormat,
57}
58
59impl Default for LinkFetchTool {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl LinkFetchTool {
66 #[must_use]
72 pub fn new() -> Self {
73 let client = reqwest::Client::builder()
74 .timeout(DEFAULT_TIMEOUT)
75 .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
76 .build()
77 .expect("Failed to build HTTP client");
78
79 Self {
80 client,
81 validator: UrlValidator::new(),
82 default_format: FetchFormat::Text,
83 }
84 }
85
86 #[must_use]
88 pub fn with_validator(mut self, validator: UrlValidator) -> Self {
89 self.validator = validator;
90 self
91 }
92
93 #[must_use]
95 pub fn with_client(mut self, client: reqwest::Client) -> Self {
96 self.client = client;
97 self
98 }
99
100 #[must_use]
102 pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
103 self.default_format = format;
104 self
105 }
106
107 async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
109 let url = self.validator.validate(url_str)?;
111
112 let response = self
114 .client
115 .get(url.as_str())
116 .send()
117 .await
118 .context("Failed to fetch URL")?;
119
120 if !response.status().is_success() {
122 bail!("HTTP error: {}", response.status());
123 }
124
125 if let Some(len) = response.content_length()
127 && len > MAX_CONTENT_SIZE as u64
128 {
129 bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
130 }
131
132 let content_type = response
134 .headers()
135 .get(reqwest::header::CONTENT_TYPE)
136 .and_then(|v| v.to_str().ok())
137 .unwrap_or("text/html")
138 .to_string();
139
140 let bytes = response
142 .bytes()
143 .await
144 .context("Failed to read response body")?;
145
146 if bytes.len() > MAX_CONTENT_SIZE {
147 bail!(
148 "Content too large: {} bytes (max {} bytes)",
149 bytes.len(),
150 MAX_CONTENT_SIZE
151 );
152 }
153
154 let html = String::from_utf8_lossy(&bytes);
156
157 if content_type.contains("text/html") || content_type.contains("application/xhtml") {
159 Ok(convert_html(&html, format))
160 } else if content_type.contains("text/plain") {
161 Ok(html.into_owned())
162 } else {
163 Ok(html.into_owned())
165 }
166 }
167}
168
169fn convert_html(html: &str, format: FetchFormat) -> String {
171 let result = match format {
172 FetchFormat::Text => {
173 html2text::from_read(html.as_bytes(), 80)
175 }
176 FetchFormat::Markdown => {
177 html2text::from_read(html.as_bytes(), 80)
179 }
180 };
181 result.unwrap_or_else(|_| html.to_string())
182}
183
184#[async_trait]
185impl<Ctx> Tool<Ctx> for LinkFetchTool
186where
187 Ctx: Send + Sync + 'static,
188{
189 fn name(&self) -> &'static str {
190 "link_fetch"
191 }
192
193 fn description(&self) -> &'static str {
194 "Fetch and read web page content. Returns the page content as text or markdown. \
195 Includes SSRF protection to prevent access to internal resources."
196 }
197
198 fn input_schema(&self) -> Value {
199 json!({
200 "type": "object",
201 "properties": {
202 "url": {
203 "type": "string",
204 "description": "The URL to fetch (must be HTTPS)"
205 },
206 "format": {
207 "type": "string",
208 "enum": ["text", "markdown"],
209 "description": "Output format (default: text)"
210 }
211 },
212 "required": ["url"]
213 })
214 }
215
216 fn tier(&self) -> ToolTier {
217 ToolTier::Observe
219 }
220
221 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
222 let url = input
223 .get("url")
224 .and_then(Value::as_str)
225 .context("Missing 'url' parameter")?;
226
227 let format = input
228 .get("format")
229 .and_then(Value::as_str)
230 .and_then(FetchFormat::from_str)
231 .unwrap_or(self.default_format);
232
233 match self.fetch_url(url, format).await {
234 Ok(content) => Ok(ToolResult {
235 success: true,
236 output: content,
237 data: Some(json!({ "url": url, "format": format_name(format) })),
238 duration_ms: None,
239 }),
240 Err(e) => Ok(ToolResult {
241 success: false,
242 output: format!("Failed to fetch URL: {e}"),
243 data: Some(json!({ "url": url, "error": e.to_string() })),
244 duration_ms: None,
245 }),
246 }
247 }
248}
249
250const fn format_name(format: FetchFormat) -> &'static str {
252 match format {
253 FetchFormat::Text => "text",
254 FetchFormat::Markdown => "markdown",
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a freshly built tool against `input` using a unit context.
    async fn run_with_default_tool(input: Value) -> anyhow::Result<ToolResult> {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        Tool::<()>::execute(&tool, &ctx, input).await
    }

    #[test]
    fn test_link_fetch_tool_metadata() {
        let tool = LinkFetchTool::new();
        let name = Tool::<()>::name(&tool);
        let description = Tool::<()>::description(&tool);
        let tier = Tool::<()>::tier(&tool);

        assert_eq!(name, "link_fetch");
        assert!(description.contains("Fetch"));
        assert_eq!(tier, ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&LinkFetchTool::new());
        let properties = &schema["properties"];

        assert_eq!(schema["type"], "object");
        assert!(properties["url"].is_object());
        assert!(properties["format"].is_object());

        let url_is_required = schema["required"]
            .as_array()
            .is_some_and(|arr| arr.iter().any(|v| v == "url"));
        assert!(url_is_required);
    }

    #[test]
    fn test_format_from_str() {
        let cases = [
            ("text", Some(FetchFormat::Text)),
            ("TEXT", Some(FetchFormat::Text)),
            ("markdown", Some(FetchFormat::Markdown)),
            ("md", Some(FetchFormat::Markdown)),
            ("invalid", None),
        ];
        for (raw, expected) in cases {
            assert_eq!(FetchFormat::from_str(raw), expected, "input: {raw}");
        }
    }

    #[test]
    fn test_convert_html_text() {
        let rendered = convert_html(
            "<html><body><h1>Title</h1><p>Paragraph</p></body></html>",
            FetchFormat::Text,
        );
        assert!(rendered.contains("Title"));
        assert!(rendered.contains("Paragraph"));
    }

    #[test]
    fn test_default_format() {
        assert_eq!(LinkFetchTool::new().default_format, FetchFormat::Text);

        let markdown_tool = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(markdown_tool.default_format, FetchFormat::Markdown);
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let outcome = run_with_default_tool(json!({ "url": "http://localhost:8080" })).await;
        assert!(outcome.is_ok());

        let tool_result = outcome.expect("Should succeed");
        assert!(!tool_result.success);

        let output = tool_result.output.as_str();
        assert!(output.contains("HTTPS required") || output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let outcome = run_with_default_tool(json!({})).await;
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let outcome = run_with_default_tool(json!({ "url": "not-a-valid-url" })).await;
        assert!(outcome.is_ok());

        let tool_result = outcome.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        let permissive = UrlValidator::new().with_allow_http();
        let _tool = LinkFetchTool::new().with_validator(permissive);
    }

    #[test]
    fn test_format_name() {
        for (format, expected) in [
            (FetchFormat::Text, "text"),
            (FetchFormat::Markdown, "markdown"),
        ] {
            assert_eq!(format_name(format), expected);
        }
    }
}