Skip to main content

ai_agent/tools/
web_search.rs

1use crate::types::*;
2use crate::utils::http::get_user_agent;
3use regex::Regex;
4use reqwest::Client;
5
6pub struct WebSearchTool {
7    client: Client,
8}
9
10impl WebSearchTool {
11    pub fn new() -> Self {
12        let client = Client::builder()
13            .timeout(std::time::Duration::from_secs(15))
14            .user_agent(get_user_agent())
15            .build()
16            .expect("Failed to create HTTP client");
17        Self { client }
18    }
19
20    pub fn name(&self) -> &str {
21        "WebSearch"
22    }
23
24    pub fn description(&self) -> &str {
25        "Search the web for information. Returns search results with titles, URLs, and snippets."
26    }
27
28    pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
29        "WebSearch".to_string()
30    }
31
32    pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
33        input.and_then(|inp| inp["query"].as_str().map(String::from))
34    }
35
36    pub fn render_tool_result_message(
37        &self,
38        content: &serde_json::Value,
39    ) -> Option<String> {
40        let text = content["content"].as_str()?;
41        let lines = text.lines().count();
42        Some(format!("{} lines", lines))
43    }
44
45    pub fn input_schema(&self) -> ToolInputSchema {
46        ToolInputSchema {
47            schema_type: "object".to_string(),
48            properties: serde_json::json!({
49                "query": {
50                    "type": "string",
51                    "description": "The search query"
52                },
53                "num_results": {
54                    "type": "number",
55                    "description": "Number of results to return (default: 5)"
56                }
57            }),
58            required: Some(vec!["query".to_string()]),
59        }
60    }
61
62    pub async fn execute(
63        &self,
64        input: serde_json::Value,
65        _context: &ToolContext,
66    ) -> Result<ToolResult, crate::error::AgentError> {
67        let query = input["query"]
68            .as_str()
69            .ok_or_else(|| crate::error::AgentError::Tool("query is required".to_string()))?;
70
71        let num_results = input["num_results"].as_u64().unwrap_or(5) as usize;
72
73        // Use DuckDuckGo HTML search
74        let encoded = urlencoding::encode(query);
75        let url = format!("https://html.duckduckgo.com/html/?q={}", encoded);
76
77        let response = self
78            .client
79            .get(&url)
80            .send()
81            .await
82            .map_err(|e| crate::error::AgentError::Tool(format!("Search error: {}", e)))?;
83
84        if !response.status().is_success() {
85            return Ok(ToolResult {
86                result_type: "text".to_string(),
87                tool_use_id: "".to_string(),
88                content: format!("Search failed: HTTP {}", response.status().as_u16()),
89                is_error: Some(true),
90                was_persisted: None,
91            });
92        }
93
94        let html = response.text().await.map_err(|e| {
95            crate::error::AgentError::Tool(format!("Error reading search results: {}", e))
96        })?;
97
98        // Parse search results from DuckDuckGo HTML
99        let result_regex =
100            Regex::new(r#"<a rel="nofollow" class="result__a" href="([^"]*)"[^>]*>([\s\S]*?)</a>"#)
101                .unwrap();
102        let snippet_regex =
103            Regex::new(r#"<a class="result__snippet"[^>]*>([\s\S]*?)</a>"#).unwrap();
104
105        let mut links: Vec<(String, String)> = Vec::new();
106        for cap in result_regex.captures_iter(&html) {
107            if let (Some(href), Some(title)) = (cap.get(1), cap.get(2)) {
108                let href = href.as_str().to_string();
109                let title = title.as_str().replace("<[^>]+>", "").trim().to_string();
110                if !href.is_empty() && !title.is_empty() && !href.contains("duckduckgo.com") {
111                    links.push((title, href));
112                }
113            }
114        }
115
116        let mut snippets: Vec<String> = Vec::new();
117        for cap in snippet_regex.captures_iter(&html) {
118            if let Some(snippet) = cap.get(1) {
119                let snippet_text = snippet.as_str().replace("<[^>]+>", "").trim().to_string();
120                snippets.push(snippet_text);
121            }
122        }
123
124        let mut results: Vec<String> = Vec::new();
125        let num_results = std::cmp::min(num_results, links.len());
126
127        for i in 0..num_results {
128            let (title, url) = &links[i];
129            let mut entry = format!("{}. {}\n   {}", i + 1, title, url);
130            if let Some(snippet) = snippets.get(i) {
131                if !snippet.is_empty() {
132                    entry.push_str(&format!("\n   {}", snippet));
133                }
134            }
135            results.push(entry);
136        }
137
138        let content = if results.is_empty() {
139            format!("No results found for \"{}\"", query)
140        } else {
141            results.join("\n\n")
142        };
143
144        Ok(ToolResult {
145            result_type: "text".to_string(),
146            tool_use_id: "".to_string(),
147            content,
148            is_error: None,
149            was_persisted: None,
150        })
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// The tool reports its fixed identifier.
    #[test]
    fn test_web_search_tool_name() {
        let search_tool = WebSearchTool::new();
        let reported = search_tool.name();
        assert_eq!(reported, "WebSearch");
    }

    /// The description mentions searching, regardless of letter case.
    #[test]
    fn test_web_search_tool_description_contains_search() {
        let search_tool = WebSearchTool::new();
        let lowered = search_tool.description().to_lowercase();
        assert!(lowered.contains("search"));
    }

    /// The input schema declares a `query` property.
    #[test]
    fn test_web_search_tool_has_query_in_schema() {
        let search_tool = WebSearchTool::new();
        let schema = search_tool.input_schema();
        let declared = &schema.properties;
        assert!(declared.get("query").is_some());
    }

    /// The input schema declares a `num_results` property.
    #[test]
    fn test_web_search_tool_has_num_results_in_schema() {
        let search_tool = WebSearchTool::new();
        let schema = search_tool.input_schema();
        let declared = &schema.properties;
        assert!(declared.get("num_results").is_some());
    }

    /// Executing without a `query` fails before any request is made.
    #[tokio::test]
    async fn test_web_search_tool_requires_query() {
        let search_tool = WebSearchTool::new();
        let ctx = ToolContext::default();
        let empty_input = serde_json::json!({});

        let outcome = search_tool.execute(empty_input, &ctx).await;
        assert!(outcome.is_err());
    }

    /// A live search yields non-empty content that mentions the query subject.
    #[tokio::test]
    #[ignore] // Requires network access to DuckDuckGo
    async fn test_web_search_tool_returns_results() {
        let search_tool = WebSearchTool::new();
        let ctx = ToolContext::default();
        let request = serde_json::json!({
            "query": "Rust programming language"
        });

        let outcome = search_tool.execute(request, &ctx).await;
        assert!(outcome.is_ok());
        let tool_result = outcome.unwrap();
        assert!(!tool_result.content.is_empty());
        // Should contain some expected content
        let lowered = tool_result.content.to_lowercase();
        assert!(lowered.contains("rust"));
    }

    /// A live search with an explicit `num_results` completes successfully.
    #[tokio::test]
    #[ignore] // Requires network access to DuckDuckGo
    async fn test_web_search_tool_respects_num_results() {
        let search_tool = WebSearchTool::new();
        let ctx = ToolContext::default();
        let request = serde_json::json!({
            "query": "test query",
            "num_results": 3
        });

        let outcome = search_tool.execute(request, &ctx).await;
        assert!(outcome.is_ok());
    }
}