Skip to main content

rustant_tools/
web.rs

1//! Web tools: search, fetch, and document reading.
2//!
3//! Lightweight web access tools that work without browser automation.
4//! - `web_search`: Search the web using DuckDuckGo instant answers (privacy-first).
5//! - `web_fetch`: Fetch a URL and extract readable text content.
//! - `document_read`: Read text-based documents (plain text, markup, config files) from the local filesystem; PDF is not supported.
7
8use crate::registry::Tool;
9use async_trait::async_trait;
10use rustant_core::error::ToolError;
11use rustant_core::types::{RiskLevel, ToolOutput};
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14
15// ---------------------------------------------------------------------------
16// WebSearchTool
17// ---------------------------------------------------------------------------
18
/// Web search backed by the DuckDuckGo Instant Answer API.
///
/// Produces a text listing of answers and related links (snippet plus URL).
/// Requests are sent straight to DuckDuckGo; no intermediary service is involved.
#[derive(Default)]
pub struct WebSearchTool;

impl WebSearchTool {
    /// Creates the search tool. It holds no state, so this is the same as `Default`.
    pub fn new() -> Self {
        Self::default()
    }
}
31
32#[async_trait]
33impl Tool for WebSearchTool {
34    fn name(&self) -> &str {
35        "web_search"
36    }
37
38    fn description(&self) -> &str {
39        "Search the web for information. Returns titles, snippets, and URLs from search results. \
40         Use this to look up documentation, find solutions to errors, or research topics."
41    }
42
43    fn parameters_schema(&self) -> serde_json::Value {
44        serde_json::json!({
45            "type": "object",
46            "properties": {
47                "query": {
48                    "type": "string",
49                    "description": "The search query"
50                },
51                "max_results": {
52                    "type": "integer",
53                    "description": "Maximum number of results to return (default: 5, max: 10)",
54                    "default": 5
55                }
56            },
57            "required": ["query"]
58        })
59    }
60
61    fn risk_level(&self) -> RiskLevel {
62        RiskLevel::ReadOnly
63    }
64
65    fn timeout(&self) -> Duration {
66        Duration::from_secs(15)
67    }
68
69    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
70        let query = args.get("query").and_then(|v| v.as_str()).ok_or_else(|| {
71            ToolError::InvalidArguments {
72                name: "web_search".into(),
73                reason: "Missing required parameter: query".into(),
74            }
75        })?;
76
77        let max_results = args
78            .get("max_results")
79            .and_then(|v| v.as_u64())
80            .unwrap_or(5)
81            .min(10) as usize;
82
83        // Use DuckDuckGo HTML search (no API key required, privacy-first)
84        let client = reqwest::Client::builder()
85            .timeout(Duration::from_secs(10))
86            .user_agent("Rustant/1.0")
87            .build()
88            .map_err(|e| ToolError::ExecutionFailed {
89                name: "web_search".into(),
90                message: format!("Failed to create HTTP client: {}", e),
91            })?;
92
93        // Use DuckDuckGo instant answer API
94        let url = format!(
95            "https://api.duckduckgo.com/?q={}&format=json&no_html=1&skip_disambig=1",
96            urlencoding::encode(query)
97        );
98
99        let response = client
100            .get(&url)
101            .send()
102            .await
103            .map_err(|e| ToolError::ExecutionFailed {
104                name: "web_search".into(),
105                message: format!("Search request failed: {}", e),
106            })?;
107
108        let body: serde_json::Value =
109            response
110                .json()
111                .await
112                .map_err(|e| ToolError::ExecutionFailed {
113                    name: "web_search".into(),
114                    message: format!("Failed to parse search response: {}", e),
115                })?;
116
117        let mut results = Vec::new();
118
119        // Extract abstract (main answer)
120        if let Some(abstract_text) = body.get("AbstractText").and_then(|v| v.as_str())
121            && !abstract_text.is_empty()
122        {
123            let source = body
124                .get("AbstractSource")
125                .and_then(|v| v.as_str())
126                .unwrap_or("Unknown");
127            let url = body
128                .get("AbstractURL")
129                .and_then(|v| v.as_str())
130                .unwrap_or("");
131            results.push(format!("[{}] {}\n  URL: {}", source, abstract_text, url));
132        }
133
134        // Extract related topics
135        if let Some(topics) = body.get("RelatedTopics").and_then(|v| v.as_array()) {
136            for topic in topics
137                .iter()
138                .take(max_results.saturating_sub(results.len()))
139            {
140                if let Some(text) = topic.get("Text").and_then(|v| v.as_str()) {
141                    let url = topic.get("FirstURL").and_then(|v| v.as_str()).unwrap_or("");
142                    results.push(format!("- {}\n  URL: {}", text, url));
143                }
144            }
145        }
146
147        // Extract results from Results array
148        if let Some(res_array) = body.get("Results").and_then(|v| v.as_array()) {
149            for result in res_array
150                .iter()
151                .take(max_results.saturating_sub(results.len()))
152            {
153                if let Some(text) = result.get("Text").and_then(|v| v.as_str()) {
154                    let url = result
155                        .get("FirstURL")
156                        .and_then(|v| v.as_str())
157                        .unwrap_or("");
158                    results.push(format!("- {}\n  URL: {}", text, url));
159                }
160            }
161        }
162
163        let content = if results.is_empty() {
164            format!(
165                "No instant answers found for \"{}\". Try refining your query or use web_fetch with a specific URL.",
166                query
167            )
168        } else {
169            format!(
170                "Search results for \"{}\":\n\n{}",
171                query,
172                results.join("\n\n")
173            )
174        };
175
176        Ok(ToolOutput::text(content))
177    }
178}
179
180// ---------------------------------------------------------------------------
181// WebFetchTool
182// ---------------------------------------------------------------------------
183
/// Downloads a URL and returns its readable text.
///
/// HTML markup is stripped so only the text remains. This is a lightweight
/// alternative to browser automation and needs no Chrome install.
#[derive(Default)]
pub struct WebFetchTool;

impl WebFetchTool {
    /// Creates the fetch tool. It holds no state, so this is the same as `Default`.
    pub fn new() -> Self {
        Self::default()
    }
}
196
197#[async_trait]
198impl Tool for WebFetchTool {
199    fn name(&self) -> &str {
200        "web_fetch"
201    }
202
203    fn description(&self) -> &str {
204        "Fetch a web page URL and extract its text content. Returns the readable text \
205         from the page, stripped of HTML. Use this to read documentation, articles, or \
206         API references from the web."
207    }
208
209    fn parameters_schema(&self) -> serde_json::Value {
210        serde_json::json!({
211            "type": "object",
212            "properties": {
213                "url": {
214                    "type": "string",
215                    "description": "The URL to fetch"
216                },
217                "max_length": {
218                    "type": "integer",
219                    "description": "Maximum characters of content to return (default: 5000)",
220                    "default": 5000
221                }
222            },
223            "required": ["url"]
224        })
225    }
226
227    fn risk_level(&self) -> RiskLevel {
228        RiskLevel::ReadOnly
229    }
230
231    fn timeout(&self) -> Duration {
232        Duration::from_secs(30)
233    }
234
235    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
236        let url = args.get("url").and_then(|v| v.as_str()).ok_or_else(|| {
237            ToolError::InvalidArguments {
238                name: "web_fetch".into(),
239                reason: "Missing required parameter: url".into(),
240            }
241        })?;
242
243        let max_length = args
244            .get("max_length")
245            .and_then(|v| v.as_u64())
246            .unwrap_or(5000) as usize;
247
248        // Validate URL
249        if !url.starts_with("http://") && !url.starts_with("https://") {
250            return Err(ToolError::InvalidArguments {
251                name: "web_fetch".into(),
252                reason: "URL must start with http:// or https://".into(),
253            });
254        }
255
256        let client = reqwest::Client::builder()
257            .timeout(Duration::from_secs(15))
258            .user_agent("Rustant/1.0")
259            .redirect(reqwest::redirect::Policy::limited(5))
260            .build()
261            .map_err(|e| ToolError::ExecutionFailed {
262                name: "web_fetch".into(),
263                message: format!("Failed to create HTTP client: {}", e),
264            })?;
265
266        let response = client
267            .get(url)
268            .send()
269            .await
270            .map_err(|e| ToolError::ExecutionFailed {
271                name: "web_fetch".into(),
272                message: format!("Fetch failed: {}", e),
273            })?;
274
275        let status = response.status();
276        if !status.is_success() {
277            return Ok(ToolOutput::text(format!(
278                "HTTP {} for URL: {}",
279                status, url
280            )));
281        }
282
283        let content_type = response
284            .headers()
285            .get("content-type")
286            .and_then(|v| v.to_str().ok())
287            .unwrap_or("")
288            .to_string();
289
290        let body = response
291            .text()
292            .await
293            .map_err(|e| ToolError::ExecutionFailed {
294                name: "web_fetch".into(),
295                message: format!("Failed to read response body: {}", e),
296            })?;
297
298        // Extract text from HTML
299        let text =
300            if content_type.contains("text/html") || content_type.contains("application/xhtml") {
301                extract_text_from_html(&body)
302            } else {
303                // Plain text or other formats — return as-is
304                body
305            };
306
307        // Truncate if needed
308        let text = if text.len() > max_length {
309            format!(
310                "{}...\n\n[Truncated at {} characters. Use max_length to see more.]",
311                &text[..max_length],
312                max_length
313            )
314        } else {
315            text
316        };
317
318        let content = format!("Content from {}:\n\n{}", url, text);
319
320        Ok(ToolOutput::text(content))
321    }
322}
323
/// Simple HTML-to-text extraction.
///
/// The conversion does the following:
/// - strips tags;
/// - drops the contents of `<script>` / `<style>` elements and `<!-- ... -->` comments;
/// - starts a new line at common block-level elements (paragraphs, headings,
///   list items, table rows, `<br>`, ...);
/// - decodes a few common entities;
/// - trims every line and removes blank lines.
///
/// This is a lightweight scanner, not a conforming HTML parser. It aims to
/// turn typical pages into readable text, not to handle every malformed input.
fn extract_text_from_html(html: &str) -> String {
    let mut text = String::with_capacity(html.len());
    // Raw characters between `<` and `>` of the tag currently being read.
    let mut tag_buf = String::new();
    let mut in_tag = false;
    let mut in_script = false;
    let mut in_style = false;
    let mut chars = html.chars().peekable();

    while let Some(ch) = chars.next() {
        if in_tag {
            // Comments may contain `>` and only end at `-->`. Inside script/style
            // bodies `<!--` is not treated as a comment, so a stray one there
            // cannot swallow the rest of the page.
            let in_comment = !in_script && !in_style && tag_buf.starts_with("!--");
            match ch {
                '>' if in_comment && !tag_buf.ends_with("--") => tag_buf.push(ch),
                '>' => {
                    in_tag = false;
                    let name = tag_name(&tag_buf);
                    match name.as_str() {
                        "/script" => in_script = false,
                        "/style" => in_style = false,
                        // Skip a `<style>` that appears inside a script body (e.g. in a
                        // string literal), and vice versa. Otherwise it could hide
                        // everything after it.
                        "script" if !in_style => in_script = true,
                        "style" if !in_script => in_style = true,
                        _ => {}
                    }
                    if !in_script && !in_style && is_block_tag(&name) {
                        text.push('\n');
                    }
                }
                // A new `<` inside an unfinished tag restarts the tag.
                '<' if !in_comment => tag_buf.clear(),
                _ => tag_buf.push(ch),
            }
            continue;
        }

        // `<` opens a tag only when a letter, `/`, `!` or `?` follows it.
        // Otherwise (e.g. `1 < 2`) it is literal text.
        let opens_tag = ch == '<'
            && chars.peek().is_some_and(|&next| {
                next.is_ascii_alphabetic() || matches!(next, '/' | '!' | '?')
            });
        if opens_tag {
            in_tag = true;
            tag_buf.clear();
        } else if !in_script && !in_style {
            // This includes a stray `>` outside any tag, which is ordinary text.
            text.push(ch);
        }
    }

    // Trim every line and drop blank ones entirely.
    decode_entities(&text)
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}

/// Returns the lower-cased element name at the start of a raw tag body.
/// A leading `/` is kept for closing tags, e.g. `"/P class=x"` -> `"/p"`.
fn tag_name(raw: &str) -> String {
    raw.chars()
        .take_while(|c| c.is_alphanumeric() || *c == '/')
        .collect::<String>()
        .to_lowercase()
}

/// Whether a tag name (as produced by `tag_name`) should start a new line.
/// Handles opening, closing (`/p`) and self-closing (`br/`) forms.
fn is_block_tag(name: &str) -> bool {
    matches!(
        name.trim_matches('/'),
        "p" | "br"
            | "hr"
            | "div"
            | "pre"
            | "blockquote"
            | "h1"
            | "h2"
            | "h3"
            | "h4"
            | "h5"
            | "h6"
            | "ul"
            | "ol"
            | "li"
            | "table"
            | "tr"
            | "section"
            | "article"
            | "header"
            | "footer"
            | "title"
    )
}

/// Decodes the few named and numeric entities that commonly appear in prose.
fn decode_entities(text: &str) -> String {
    text.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&apos;", "'")
        .replace("&nbsp;", " ")
        // `&amp;` goes last. An escaped entity such as `&amp;lt;` then becomes
        // the literal `&lt;` instead of being decoded twice into `<`.
        .replace("&amp;", "&")
}
413
414// ---------------------------------------------------------------------------
415// DocumentReadTool
416// ---------------------------------------------------------------------------
417
418/// Read documents from the local filesystem (plain text and common formats).
419///
420/// Supports: .txt, .md, .csv, .json, .yaml, .toml, .xml, .log files.
421/// For PDF support, the `pdf-extract` crate would be needed (not included
422/// by default to keep dependencies minimal).
423pub struct DocumentReadTool {
424    workspace: PathBuf,
425}
426
427impl DocumentReadTool {
428    pub fn new(workspace: PathBuf) -> Self {
429        Self { workspace }
430    }
431
432    fn resolve_path(&self, path: &str) -> Result<PathBuf, ToolError> {
433        let resolved = if Path::new(path).is_absolute() {
434            PathBuf::from(path)
435        } else {
436            self.workspace.join(path)
437        };
438
439        let canonical = resolved
440            .canonicalize()
441            .map_err(|e| ToolError::ExecutionFailed {
442                name: "document_read".into(),
443                message: format!("Path resolution failed: {}", e),
444            })?;
445
446        // Allow reading outside workspace for documents (e.g., ~/Downloads/*.pdf)
447        // but still validate the path exists
448        if !canonical.exists() {
449            return Err(ToolError::ExecutionFailed {
450                name: "document_read".into(),
451                message: format!("File not found: {}", path),
452            });
453        }
454
455        Ok(canonical)
456    }
457}
458
459#[async_trait]
460impl Tool for DocumentReadTool {
461    fn name(&self) -> &str {
462        "document_read"
463    }
464
465    fn description(&self) -> &str {
466        "Read a document file and extract its text content. Supports text-based formats: \
467         .txt, .md, .csv, .json, .yaml, .yml, .toml, .xml, .log, .cfg, .ini, .html. \
468         Returns the file content as text."
469    }
470
471    fn parameters_schema(&self) -> serde_json::Value {
472        serde_json::json!({
473            "type": "object",
474            "properties": {
475                "path": {
476                    "type": "string",
477                    "description": "Path to the document file (relative to workspace or absolute)"
478                },
479                "max_length": {
480                    "type": "integer",
481                    "description": "Maximum characters to return (default: 10000)",
482                    "default": 10000
483                }
484            },
485            "required": ["path"]
486        })
487    }
488
489    fn risk_level(&self) -> RiskLevel {
490        RiskLevel::ReadOnly
491    }
492
493    fn timeout(&self) -> Duration {
494        Duration::from_secs(10)
495    }
496
497    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
498        let path_str = args.get("path").and_then(|v| v.as_str()).ok_or_else(|| {
499            ToolError::InvalidArguments {
500                name: "document_read".into(),
501                reason: "Missing required parameter: path".into(),
502            }
503        })?;
504
505        let max_length = args
506            .get("max_length")
507            .and_then(|v| v.as_u64())
508            .unwrap_or(10000) as usize;
509
510        let path = self.resolve_path(path_str)?;
511
512        let extension = path
513            .extension()
514            .and_then(|e| e.to_str())
515            .unwrap_or("")
516            .to_lowercase();
517
518        // Validate supported extensions
519        let supported = [
520            "txt",
521            "md",
522            "csv",
523            "json",
524            "yaml",
525            "yml",
526            "toml",
527            "xml",
528            "log",
529            "cfg",
530            "ini",
531            "html",
532            "htm",
533            "rst",
534            "adoc",
535            "tex",
536            "rtf",
537            "conf",
538            "properties",
539            "env",
540        ];
541
542        if !supported.contains(&extension.as_str()) {
543            return Err(ToolError::InvalidArguments {
544                name: "document_read".into(),
545                reason: format!(
546                    "Unsupported file format '.{}'. Supported: {}",
547                    extension,
548                    supported.join(", ")
549                ),
550            });
551        }
552
553        // Read file
554        let content = std::fs::read_to_string(&path).map_err(|e| ToolError::ExecutionFailed {
555            name: "document_read".into(),
556            message: format!("Failed to read file: {}", e),
557        })?;
558
559        // For HTML files, extract text
560        let text = if extension == "html" || extension == "htm" {
561            extract_text_from_html(&content)
562        } else {
563            content
564        };
565
566        // Truncate if needed
567        let text = if text.len() > max_length {
568            format!(
569                "{}...\n\n[Truncated at {} characters. Use max_length to see more.]",
570                &text[..max_length],
571                max_length
572            )
573        } else {
574            text
575        };
576
577        let file_size = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
578
579        let content = format!(
580            "Document: {} ({} bytes, .{}):\n\n{}",
581            path.display(),
582            file_size,
583            extension,
584            text
585        );
586
587        Ok(ToolOutput::text(content))
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // --- Helpers -----------------------------------------------------------

    /// Builds a `DocumentReadTool` rooted at a fresh temporary workspace. The
    /// `TempDir` guard is returned so the directory outlives every use of the tool.
    fn document_tool() -> (TempDir, DocumentReadTool) {
        let workspace = TempDir::new().unwrap();
        let tool = DocumentReadTool::new(workspace.path().to_path_buf());
        (workspace, tool)
    }

    /// Writes `contents` to `workspace/file_name` and returns the absolute
    /// path as a `String`, ready to pass as the tool's `path` argument.
    fn write_file(workspace: &TempDir, file_name: &str, contents: &str) -> String {
        let file_path = workspace.path().join(file_name);
        std::fs::write(&file_path, contents).unwrap();
        file_path.to_str().unwrap().to_owned()
    }

    /// Writes a document into a fresh workspace and reads it back via the tool.
    /// Panics if the read fails.
    async fn read_document(file_name: &str, contents: &str) -> ToolOutput {
        let (workspace, tool) = document_tool();
        let path = write_file(&workspace, file_name, contents);
        tool.execute(serde_json::json!({ "path": path }))
            .await
            .unwrap()
    }

    /// Runs `web_fetch` for `url` and returns the raw result.
    async fn fetch(url: &str) -> Result<ToolOutput, ToolError> {
        WebFetchTool::new()
            .execute(serde_json::json!({ "url": url }))
            .await
    }

    /// True when `field` is listed in the schema's `required` array.
    fn schema_requires(schema: &serde_json::Value, field: &str) -> bool {
        schema["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!(field))
    }

    // --- HTML extraction ---------------------------------------------------

    #[test]
    fn test_extract_text_from_html() {
        let html = r#"
        <html>
        <head><title>Test</title></head>
        <body>
            <h1>Hello World</h1>
            <p>This is a <b>test</b> paragraph.</p>
            <script>var x = 1;</script>
            <style>.foo { color: red; }</style>
            <ul>
                <li>Item 1</li>
                <li>Item 2</li>
            </ul>
        </body>
        </html>"#;

        let text = extract_text_from_html(html);
        for visible in ["Hello World", "This is a test paragraph.", "Item 1", "Item 2"] {
            assert!(text.contains(visible), "expected {visible:?} in {text:?}");
        }
        for hidden in ["var x = 1", "color: red"] {
            assert!(!text.contains(hidden), "did not expect {hidden:?} in {text:?}");
        }
    }

    #[test]
    fn test_extract_text_html_entities() {
        let text = extract_text_from_html("<p>A &amp; B &lt; C &gt; D &quot;E&quot;</p>");
        assert!(text.contains("A & B < C > D \"E\""));
    }

    #[test]
    fn test_extract_text_from_html_nested_scripts() {
        // Script and style bodies must never leak into the extracted text.
        let html = "<script>alert('xss')</script><p>safe</p><style>body{}</style>";
        let text = extract_text_from_html(html);
        for hidden in ["alert", "xss", "body{}"] {
            assert!(!text.contains(hidden), "did not expect {hidden:?} in {text:?}");
        }
        assert!(text.contains("safe"));
    }

    #[test]
    fn test_extract_text_from_html_empty() {
        assert!(extract_text_from_html("").is_empty());
    }

    #[test]
    fn test_extract_text_from_html_plain_text() {
        let sentence = "Just plain text with no tags";
        assert!(extract_text_from_html(sentence).contains(sentence));
    }

    // --- web_search --------------------------------------------------------

    #[tokio::test]
    async fn test_web_search_tool_schema() {
        let search = WebSearchTool::new();
        assert_eq!(search.name(), "web_search");
        assert_eq!(search.risk_level(), RiskLevel::ReadOnly);

        let schema = search.parameters_schema();
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("query").is_some());
    }

    #[tokio::test]
    async fn test_web_search_schema_has_required() {
        assert!(schema_requires(&WebSearchTool::new().parameters_schema(), "query"));
    }

    #[tokio::test]
    async fn test_web_search_missing_query() {
        let outcome = WebSearchTool::new().execute(serde_json::json!({})).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_web_search_timeout() {
        assert_eq!(WebSearchTool::new().timeout(), Duration::from_secs(15));
    }

    // --- web_fetch ---------------------------------------------------------

    #[tokio::test]
    async fn test_web_fetch_tool_schema() {
        let fetcher = WebFetchTool::new();
        assert_eq!(fetcher.name(), "web_fetch");
        assert_eq!(fetcher.risk_level(), RiskLevel::ReadOnly);
        assert!(fetcher.parameters_schema()["properties"].get("url").is_some());
    }

    #[tokio::test]
    async fn test_web_fetch_schema_has_required() {
        assert!(schema_requires(&WebFetchTool::new().parameters_schema(), "url"));
    }

    #[tokio::test]
    async fn test_web_fetch_invalid_url() {
        assert!(fetch("not-a-url").await.is_err());
    }

    // Security-oriented checks: only http(s) may ever be fetched.

    #[tokio::test]
    async fn test_web_fetch_rejects_file_protocol() {
        let outcome = fetch("file:///etc/passwd").await;
        assert!(outcome.is_err());
        let ToolError::InvalidArguments { reason, .. } = outcome.unwrap_err() else {
            panic!("Expected InvalidArguments for a file:// URL");
        };
        assert!(
            reason.contains("http://") || reason.contains("https://"),
            "Error should mention valid protocols, got: {}",
            reason
        );
    }

    #[tokio::test]
    async fn test_web_fetch_rejects_javascript_protocol() {
        assert!(fetch("javascript:alert(1)").await.is_err());
    }

    #[tokio::test]
    async fn test_web_fetch_rejects_data_protocol() {
        assert!(fetch("data:text/html,<script>alert(1)</script>").await.is_err());
    }

    #[tokio::test]
    async fn test_web_fetch_rejects_ftp_protocol() {
        assert!(fetch("ftp://example.com/file").await.is_err());
    }

    #[tokio::test]
    async fn test_web_fetch_missing_url_param() {
        let outcome = WebFetchTool::new().execute(serde_json::json!({})).await;
        assert!(outcome.is_err());
        let ToolError::InvalidArguments { name, reason } = outcome.unwrap_err() else {
            panic!("Expected InvalidArguments when url is missing");
        };
        assert_eq!(name, "web_fetch");
        assert!(reason.contains("url"));
    }

    #[test]
    fn test_web_fetch_timeout() {
        assert_eq!(WebFetchTool::new().timeout(), Duration::from_secs(30));
    }

    // --- document_read -----------------------------------------------------

    #[tokio::test]
    async fn test_document_read_tool_schema() {
        let (_workspace, reader) = document_tool();
        assert_eq!(reader.name(), "document_read");
        assert_eq!(reader.risk_level(), RiskLevel::ReadOnly);
    }

    #[tokio::test]
    async fn test_document_read_schema_has_required() {
        let (_workspace, reader) = document_tool();
        assert!(schema_requires(&reader.parameters_schema(), "path"));
    }

    #[test]
    fn test_document_read_timeout() {
        let (_workspace, reader) = document_tool();
        assert_eq!(reader.timeout(), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_document_read_text_file() {
        let output = read_document("test.txt", "Hello, this is a test document.").await;
        assert!(output.content.contains("Hello, this is a test document."));
    }

    #[tokio::test]
    async fn test_document_read_markdown_file() {
        let output = read_document("readme.md", "# Title\n\nSome content.").await;
        assert!(output.content.contains("# Title"));
    }

    #[tokio::test]
    async fn test_document_read_json_file() {
        let output = read_document("data.json", r#"{"key": "value"}"#).await;
        assert!(output.content.contains(r#""key": "value""#));
    }

    #[tokio::test]
    async fn test_document_read_toml_file() {
        let output = read_document("config.toml", "[section]\nkey = \"value\"").await;
        assert!(output.content.contains("key = \"value\""));
    }

    #[tokio::test]
    async fn test_document_read_html_file() {
        let output = read_document("page.html", "<h1>Title</h1><p>Content here.</p>").await;
        assert!(output.content.contains("Title"));
        assert!(output.content.contains("Content here."));
        // Markup must be stripped.
        assert!(!output.content.contains("<h1>"));
    }

    #[tokio::test]
    async fn test_document_read_truncation() {
        let (workspace, reader) = document_tool();
        let path = write_file(&workspace, "long.txt", &"a".repeat(20000));
        let output = reader
            .execute(serde_json::json!({ "path": path, "max_length": 100 }))
            .await
            .unwrap();
        assert!(output.content.contains("Truncated"));
    }

    #[tokio::test]
    async fn test_document_read_unsupported_format() {
        let (workspace, reader) = document_tool();
        let path = write_file(&workspace, "binary.exe", "fake binary");
        let outcome = reader.execute(serde_json::json!({ "path": path })).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_document_read_missing_param() {
        let (_workspace, reader) = document_tool();
        assert!(reader.execute(serde_json::json!({})).await.is_err());
    }

    #[tokio::test]
    async fn test_document_read_nonexistent_file() {
        let (_workspace, reader) = document_tool();
        let outcome = reader
            .execute(serde_json::json!({ "path": "/nonexistent/file.txt" }))
            .await;
        assert!(outcome.is_err());
    }
}