1use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result, bail};
6use serde_json::{Value, json};
7use std::time::Duration;
8
9use super::security::UrlValidator;
10
/// Largest response body, in bytes, that the fetcher accepts (1 MiB).
const MAX_CONTENT_SIZE: usize = 1 << 20;

/// Request timeout configured on the default HTTP client (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
16
/// Output format requested for fetched page content.
///
/// [`FetchFormat::Text`] is the default variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FetchFormat {
    /// Text output (the default).
    #[default]
    Text,
    /// Markdown output.
    Markdown,
}
26
27impl FetchFormat {
28 fn from_str(s: &str) -> Option<Self> {
29 match s.to_lowercase().as_str() {
30 "text" => Some(Self::Text),
31 "markdown" | "md" => Some(Self::Markdown),
32 _ => None,
33 }
34 }
35}
36
37pub struct LinkFetchTool {
53 client: reqwest::Client,
54 validator: UrlValidator,
55 default_format: FetchFormat,
56}
57
58impl Default for LinkFetchTool {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl LinkFetchTool {
65 #[must_use]
71 pub fn new() -> Self {
72 let client = reqwest::Client::builder()
73 .timeout(DEFAULT_TIMEOUT)
74 .user_agent("Mozilla/5.0 (compatible; AgentSDK/1.0)")
75 .build()
76 .expect("Failed to build HTTP client");
77
78 Self {
79 client,
80 validator: UrlValidator::new(),
81 default_format: FetchFormat::Text,
82 }
83 }
84
85 #[must_use]
87 pub fn with_validator(mut self, validator: UrlValidator) -> Self {
88 self.validator = validator;
89 self
90 }
91
92 #[must_use]
94 pub fn with_client(mut self, client: reqwest::Client) -> Self {
95 self.client = client;
96 self
97 }
98
99 #[must_use]
101 pub const fn with_default_format(mut self, format: FetchFormat) -> Self {
102 self.default_format = format;
103 self
104 }
105
106 async fn fetch_url(&self, url_str: &str, format: FetchFormat) -> Result<String> {
108 let url = self.validator.validate(url_str)?;
110
111 let response = self
113 .client
114 .get(url.as_str())
115 .send()
116 .await
117 .context("Failed to fetch URL")?;
118
119 if !response.status().is_success() {
121 bail!("HTTP error: {}", response.status());
122 }
123
124 if let Some(len) = response.content_length()
126 && len > MAX_CONTENT_SIZE as u64
127 {
128 bail!("Content too large: {len} bytes (max {MAX_CONTENT_SIZE} bytes)");
129 }
130
131 let content_type = response
133 .headers()
134 .get(reqwest::header::CONTENT_TYPE)
135 .and_then(|v| v.to_str().ok())
136 .unwrap_or("text/html")
137 .to_string();
138
139 let bytes = response
141 .bytes()
142 .await
143 .context("Failed to read response body")?;
144
145 if bytes.len() > MAX_CONTENT_SIZE {
146 bail!(
147 "Content too large: {} bytes (max {} bytes)",
148 bytes.len(),
149 MAX_CONTENT_SIZE
150 );
151 }
152
153 let html = String::from_utf8_lossy(&bytes);
155
156 if content_type.contains("text/html") || content_type.contains("application/xhtml") {
158 Ok(convert_html(&html, format))
159 } else if content_type.contains("text/plain") {
160 Ok(html.into_owned())
161 } else {
162 Ok(html.into_owned())
164 }
165 }
166}
167
168fn convert_html(html: &str, format: FetchFormat) -> String {
170 let result = match format {
171 FetchFormat::Text => {
172 html2text::from_read(html.as_bytes(), 80)
174 }
175 FetchFormat::Markdown => {
176 html2text::from_read(html.as_bytes(), 80)
178 }
179 };
180 result.unwrap_or_else(|_| html.to_string())
181}
182
183impl<Ctx> Tool<Ctx> for LinkFetchTool
184where
185 Ctx: Send + Sync + 'static,
186{
187 type Name = PrimitiveToolName;
188
189 fn name(&self) -> PrimitiveToolName {
190 PrimitiveToolName::LinkFetch
191 }
192
193 fn display_name(&self) -> &'static str {
194 "Fetch URL"
195 }
196
197 fn description(&self) -> &'static str {
198 "Fetch and read web page content. Returns the page content as text or markdown. \
199 Includes SSRF protection to prevent access to internal resources."
200 }
201
202 fn input_schema(&self) -> Value {
203 json!({
204 "type": "object",
205 "properties": {
206 "url": {
207 "type": "string",
208 "description": "The URL to fetch (must be HTTPS)"
209 },
210 "format": {
211 "type": "string",
212 "enum": ["text", "markdown"],
213 "description": "Output format (default: text)"
214 }
215 },
216 "required": ["url"]
217 })
218 }
219
220 fn tier(&self) -> ToolTier {
221 ToolTier::Observe
223 }
224
225 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
226 let url = input
227 .get("url")
228 .and_then(Value::as_str)
229 .context("Missing 'url' parameter")?;
230
231 let format = input
232 .get("format")
233 .and_then(Value::as_str)
234 .and_then(FetchFormat::from_str)
235 .unwrap_or(self.default_format);
236
237 match self.fetch_url(url, format).await {
238 Ok(content) => Ok(ToolResult {
239 success: true,
240 output: content,
241 data: Some(json!({ "url": url, "format": format_name(format) })),
242 documents: Vec::new(),
243 duration_ms: None,
244 }),
245 Err(e) => Ok(ToolResult {
246 success: false,
247 output: format!("Failed to fetch URL: {e}"),
248 data: Some(json!({ "url": url, "error": e.to_string() })),
249 documents: Vec::new(),
250 duration_ms: None,
251 }),
252 }
253 }
254}
255
256const fn format_name(format: FetchFormat) -> &'static str {
258 match format {
259 FetchFormat::Text => "text",
260 FetchFormat::Markdown => "markdown",
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a freshly built tool against `input` with a unit context.
    async fn run_tool(input: Value) -> Result<ToolResult> {
        let tool = LinkFetchTool::new();
        let ctx = ToolContext::new(());
        Tool::<()>::execute(&tool, &ctx, input).await
    }

    #[test]
    fn test_link_fetch_tool_metadata() {
        let fetcher = LinkFetchTool::new();

        assert_eq!(Tool::<()>::name(&fetcher), PrimitiveToolName::LinkFetch);
        assert!(Tool::<()>::description(&fetcher).contains("Fetch"));
        assert_eq!(Tool::<()>::tier(&fetcher), ToolTier::Observe);
    }

    #[test]
    fn test_link_fetch_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&LinkFetchTool::new());
        let properties = &schema["properties"];

        assert_eq!(schema["type"], "object");
        assert!(properties["url"].is_object());
        assert!(properties["format"].is_object());

        let url_is_required = schema["required"]
            .as_array()
            .is_some_and(|required| required.iter().any(|entry| entry == "url"));
        assert!(url_is_required);
    }

    #[test]
    fn test_format_from_str() {
        let cases = [
            ("text", Some(FetchFormat::Text)),
            ("TEXT", Some(FetchFormat::Text)),
            ("markdown", Some(FetchFormat::Markdown)),
            ("md", Some(FetchFormat::Markdown)),
            ("invalid", None),
        ];

        for (raw, expected) in cases {
            assert_eq!(FetchFormat::from_str(raw), expected);
        }
    }

    #[test]
    fn test_convert_html_text() {
        let rendered = convert_html(
            "<html><body><h1>Title</h1><p>Paragraph</p></body></html>",
            FetchFormat::Text,
        );

        for needle in ["Title", "Paragraph"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_default_format() {
        let plain = LinkFetchTool::new();
        assert_eq!(plain.default_format, FetchFormat::Text);

        let markdown = LinkFetchTool::new().with_default_format(FetchFormat::Markdown);
        assert_eq!(markdown.default_format, FetchFormat::Markdown);
    }

    #[tokio::test]
    async fn test_link_fetch_blocked_url() {
        let result = run_tool(json!({ "url": "http://localhost:8080" })).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);

        let output = &tool_result.output;
        assert!(output.contains("HTTPS required") || output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_link_fetch_missing_url() {
        let result = run_tool(json!({})).await;
        assert!(result.is_err());

        let message = result.unwrap_err().to_string();
        assert!(message.contains("url"));
    }

    #[tokio::test]
    async fn test_link_fetch_invalid_url() {
        let result = run_tool(json!({ "url": "not-a-valid-url" })).await;
        assert!(result.is_ok());

        let tool_result = result.expect("Should succeed");
        assert!(!tool_result.success);
        assert!(tool_result.output.contains("Invalid URL"));
    }

    #[test]
    fn test_with_validator() {
        // Only checks that the builder accepts a customized validator.
        let _tool = LinkFetchTool::new().with_validator(UrlValidator::new().with_allow_http());
    }

    #[test]
    fn test_format_name() {
        let expectations = [
            (FetchFormat::Text, "text"),
            (FetchFormat::Markdown, "markdown"),
        ];

        for (format, name) in expectations {
            assert_eq!(format_name(format), name);
        }
    }
}