ai_agent/tools/
web_search.rs1use crate::types::*;
2use crate::utils::http::get_user_agent;
3use regex::Regex;
4use reqwest::Client;
5
6pub struct WebSearchTool {
7 client: Client,
8}
9
10impl WebSearchTool {
11 pub fn new() -> Self {
12 let client = Client::builder()
13 .timeout(std::time::Duration::from_secs(15))
14 .user_agent(get_user_agent())
15 .build()
16 .expect("Failed to create HTTP client");
17 Self { client }
18 }
19
20 pub fn name(&self) -> &str {
21 "WebSearch"
22 }
23
24 pub fn description(&self) -> &str {
25 "Search the web for information. Returns search results with titles, URLs, and snippets."
26 }
27
28 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
29 "WebSearch".to_string()
30 }
31
32 pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
33 input.and_then(|inp| inp["query"].as_str().map(String::from))
34 }
35
36 pub fn render_tool_result_message(
37 &self,
38 content: &serde_json::Value,
39 ) -> Option<String> {
40 let text = content["content"].as_str()?;
41 let lines = text.lines().count();
42 Some(format!("{} lines", lines))
43 }
44
45 pub fn input_schema(&self) -> ToolInputSchema {
46 ToolInputSchema {
47 schema_type: "object".to_string(),
48 properties: serde_json::json!({
49 "query": {
50 "type": "string",
51 "description": "The search query"
52 },
53 "num_results": {
54 "type": "number",
55 "description": "Number of results to return (default: 5)"
56 }
57 }),
58 required: Some(vec!["query".to_string()]),
59 }
60 }
61
62 pub async fn execute(
63 &self,
64 input: serde_json::Value,
65 _context: &ToolContext,
66 ) -> Result<ToolResult, crate::error::AgentError> {
67 let query = input["query"]
68 .as_str()
69 .ok_or_else(|| crate::error::AgentError::Tool("query is required".to_string()))?;
70
71 let num_results = input["num_results"].as_u64().unwrap_or(5) as usize;
72
73 let encoded = urlencoding::encode(query);
75 let url = format!("https://html.duckduckgo.com/html/?q={}", encoded);
76
77 let response = self
78 .client
79 .get(&url)
80 .send()
81 .await
82 .map_err(|e| crate::error::AgentError::Tool(format!("Search error: {}", e)))?;
83
84 if !response.status().is_success() {
85 return Ok(ToolResult {
86 result_type: "text".to_string(),
87 tool_use_id: "".to_string(),
88 content: format!("Search failed: HTTP {}", response.status().as_u16()),
89 is_error: Some(true),
90 was_persisted: None,
91 });
92 }
93
94 let html = response.text().await.map_err(|e| {
95 crate::error::AgentError::Tool(format!("Error reading search results: {}", e))
96 })?;
97
98 let result_regex =
100 Regex::new(r#"<a rel="nofollow" class="result__a" href="([^"]*)"[^>]*>([\s\S]*?)</a>"#)
101 .unwrap();
102 let snippet_regex =
103 Regex::new(r#"<a class="result__snippet"[^>]*>([\s\S]*?)</a>"#).unwrap();
104
105 let mut links: Vec<(String, String)> = Vec::new();
106 for cap in result_regex.captures_iter(&html) {
107 if let (Some(href), Some(title)) = (cap.get(1), cap.get(2)) {
108 let href = href.as_str().to_string();
109 let title = title.as_str().replace("<[^>]+>", "").trim().to_string();
110 if !href.is_empty() && !title.is_empty() && !href.contains("duckduckgo.com") {
111 links.push((title, href));
112 }
113 }
114 }
115
116 let mut snippets: Vec<String> = Vec::new();
117 for cap in snippet_regex.captures_iter(&html) {
118 if let Some(snippet) = cap.get(1) {
119 let snippet_text = snippet.as_str().replace("<[^>]+>", "").trim().to_string();
120 snippets.push(snippet_text);
121 }
122 }
123
124 let mut results: Vec<String> = Vec::new();
125 let num_results = std::cmp::min(num_results, links.len());
126
127 for i in 0..num_results {
128 let (title, url) = &links[i];
129 let mut entry = format!("{}. {}\n {}", i + 1, title, url);
130 if let Some(snippet) = snippets.get(i) {
131 if !snippet.is_empty() {
132 entry.push_str(&format!("\n {}", snippet));
133 }
134 }
135 results.push(entry);
136 }
137
138 let content = if results.is_empty() {
139 format!("No results found for \"{}\"", query)
140 } else {
141 results.join("\n\n")
142 };
143
144 Ok(ToolResult {
145 result_type: "text".to_string(),
146 tool_use_id: "".to_string(),
147 content,
148 is_error: None,
149 was_persisted: None,
150 })
151 }
152}
153
#[cfg(test)]
mod tests {
    //! Unit tests for `WebSearchTool`. Tests that perform a real search are
    //! marked `#[ignore]` and must be requested explicitly.

    use super::*;

    #[test]
    fn test_web_search_tool_name() {
        let tool = WebSearchTool::new();
        assert_eq!(tool.name(), "WebSearch");
    }

    #[test]
    fn test_web_search_tool_description_contains_search() {
        let tool = WebSearchTool::new();
        assert!(tool.description().to_lowercase().contains("search"));
    }

    #[test]
    fn test_web_search_tool_has_query_in_schema() {
        let tool = WebSearchTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("query").is_some());
    }

    #[test]
    fn test_web_search_tool_has_num_results_in_schema() {
        let tool = WebSearchTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("num_results").is_some());
    }

    // The missing `query` is rejected before any request is built, so this
    // test does not touch the network.
    #[tokio::test]
    async fn test_web_search_tool_requires_query() {
        let tool = WebSearchTool::new();
        let input = serde_json::json!({});
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    #[ignore = "requires network access to the search endpoint"]
    async fn test_web_search_tool_returns_results() {
        let tool = WebSearchTool::new();
        let input = serde_json::json!({
            "query": "Rust programming language"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(!tool_result.content.is_empty());
        assert!(tool_result.content.to_lowercase().contains("rust"));
    }

    #[tokio::test]
    #[ignore = "requires network access to the search endpoint"]
    async fn test_web_search_tool_respects_num_results() {
        let tool = WebSearchTool::new();
        let input = serde_json::json!({
            "query": "test query",
            "num_results": 3
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();

        // Entries are numbered from 1 at the start of a line, so a fourth
        // entry would show up as a line beginning with "4. ".
        assert!(
            !tool_result
                .content
                .lines()
                .any(|line| line.starts_with("4. ")),
            "expected at most 3 results, got:\n{}",
            tool_result.content
        );
    }
}