use std::collections::{HashMap, HashSet};
use std::sync::LazyLock;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
use parking_lot::RwLock;
use regex::Regex;
use serde::Deserialize;
use serde_json::json;
10
11const CACHE_TTL: Duration = Duration::from_secs(15 * 60);
12const DEFAULT_MAX_RESULTS: usize = 10;
13const ABSOLUTE_MAX_RESULTS: usize = 20;
14const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
15
16#[derive(Debug, Deserialize)]
17struct WebSearchArgs {
18 query: String,
19 #[serde(default)]
20 allowed_domains: Option<Vec<String>>,
21 #[serde(default)]
22 blocked_domains: Option<Vec<String>>,
23 #[serde(default)]
24 max_results: Option<usize>,
25}
26
27struct CachedSearch {
28 results: serde_json::Value,
29 expires_at: Instant,
30}
31
32static SEARCH_CACHE: LazyLock<RwLock<HashMap<String, CachedSearch>>> =
33 LazyLock::new(|| RwLock::new(HashMap::new()));
34
35pub struct WebSearchTool;
36
37impl WebSearchTool {
38 pub fn new() -> Self {
39 Self
40 }
41
42 fn cache_key(
43 query: &str,
44 allowed: &Option<Vec<String>>,
45 blocked: &Option<Vec<String>>,
46 ) -> String {
47 let mut key = query.to_string();
48 if let Some(domains) = allowed {
49 key.push('|');
50 key.push_str(&domains.join(","));
51 }
52 key.push('|');
53 if let Some(domains) = blocked {
54 key.push_str(&domains.join(","));
55 }
56 key
57 }
58
59 fn try_cache(key: &str) -> Option<serde_json::Value> {
60 let cache = SEARCH_CACHE.read();
61 let entry = cache.get(key)?;
62 if entry.expires_at > Instant::now() {
63 Some(entry.results.clone())
64 } else {
65 None
66 }
67 }
68
69 fn put_cache(key: String, results: serde_json::Value) {
70 let mut cache = SEARCH_CACHE.write();
71 cache.insert(
72 key,
73 CachedSearch {
74 results,
75 expires_at: Instant::now() + CACHE_TTL,
76 },
77 );
78 }
79
80 fn decode_duckduckgo_url(raw: &str) -> Option<String> {
81 if let Ok(url) = url::Url::parse(raw) {
82 if let Some(value) = url
83 .query_pairs()
84 .find(|(key, _)| key == "uddg")
85 .map(|(_, value)| value.to_string())
86 {
87 return Some(value);
88 }
89 }
90
91 Some(raw.to_string())
92 }
93
94 fn host_of(url: &str) -> Option<String> {
95 url::Url::parse(url)
96 .ok()
97 .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
98 }
99
100 fn domain_matches(host: &str, domain: &str) -> bool {
101 host == domain || host.ends_with(&format!(".{}", domain))
102 }
103}
104
105impl Default for WebSearchTool {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111#[async_trait]
112impl Tool for WebSearchTool {
113 fn name(&self) -> &str {
114 "WebSearch"
115 }
116
117 fn description(&self) -> &str {
118 "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
119 }
120
121 fn mutability(&self) -> crate::ToolMutability {
122 crate::ToolMutability::ReadOnly
123 }
124
125 fn concurrency_safe(&self) -> bool {
126 true
127 }
128
129 fn parameters_schema(&self) -> serde_json::Value {
130 json!({
131 "type": "object",
132 "properties": {
133 "query": {
134 "type": "string",
135 "minLength": 2,
136 "description": "The search query to use"
137 },
138 "allowed_domains": {
139 "type": "array",
140 "items": { "type": "string" },
141 "description": "Only include results from these domains"
142 },
143 "blocked_domains": {
144 "type": "array",
145 "items": { "type": "string" },
146 "description": "Never include results from these domains"
147 },
148 "max_results": {
149 "type": "number",
150 "description": "Maximum results to return (default 10, max 20)"
151 }
152 },
153 "required": ["query"],
154 "additionalProperties": false
155 })
156 }
157
158 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
159 self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
160 .await
161 }
162
163 async fn execute_with_context(
164 &self,
165 args: serde_json::Value,
166 ctx: ToolExecutionContext<'_>,
167 ) -> Result<ToolResult, ToolError> {
168 let parsed: WebSearchArgs = serde_json::from_value(args)
169 .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
170
171 let query = parsed.query.trim();
172 if query.len() < 2 {
173 return Err(ToolError::InvalidArguments(
174 "query must be at least 2 characters".to_string(),
175 ));
176 }
177
178 let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
179 let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
180
181 if allowed_domains.is_some() && blocked_domains.is_some() {
183 return Err(ToolError::InvalidArguments(
184 "Cannot specify both allowed_domains and blocked_domains in the same request"
185 .to_string(),
186 ));
187 }
188
189 let max_results = parsed
190 .max_results
191 .unwrap_or(DEFAULT_MAX_RESULTS)
192 .min(ABSOLUTE_MAX_RESULTS);
193
194 let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
196 if let Some(cached) = Self::try_cache(&cache_key) {
197 ctx.emit_tool_token("Using cached search results\n").await;
198 return Ok(ToolResult {
199 success: true,
200 result: cached.to_string(),
201 display_preference: Some("Collapsible".to_string()),
202 });
203 }
204
205 ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
206
207 let client = reqwest::Client::builder()
208 .timeout(Duration::from_secs(30))
209 .build()
210 .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
211
212 let response = client
213 .get("https://duckduckgo.com/html/")
214 .header("User-Agent", USER_AGENT)
215 .query(&[("q", query)])
216 .send()
217 .await
218 .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
219
220 let html = response.text().await.map_err(|e| {
221 ToolError::Execution(format!("Failed to decode web search body: {}", e))
222 })?;
223
224 if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
226 {
227 return Err(ToolError::Execution(
228 "Search blocked by anti-bot protection. Please retry.".to_string(),
229 ));
230 }
231
232 let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
233 domains
234 .into_iter()
235 .map(|value| value.to_ascii_lowercase())
236 .collect()
237 });
238 let blocked: HashSet<String> = blocked_domains
239 .unwrap_or_default()
240 .into_iter()
241 .map(|value| value.to_ascii_lowercase())
242 .collect();
243
244 let link_re = Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
245 .map_err(|e| {
246 ToolError::Execution(format!("Failed to compile link regex: {}", e))
247 })?;
248 let tag_re = Regex::new(r"(?is)<[^>]+>")
249 .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
250 let snippet_re =
251 Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
252 .map_err(|e| {
253 ToolError::Execution(format!("Failed to compile snippet regex: {}", e))
254 })?;
255
256 let href_re = Regex::new(r#"href="([^"]+)""#)
258 .map_err(|e| ToolError::Execution(format!("Failed to compile href regex: {}", e)))?;
259 let mut snippets: HashMap<String, String> = HashMap::new();
260 for cap in snippet_re.captures_iter(&html) {
261 if let Some(href_cap) = cap.get(0) {
262 let href_text = href_cap.as_str();
263 if let Some(url_match) = href_re.find(href_text) {
265 let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
266 if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
267 let snippet_text = cap
268 .get(1)
269 .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
270 .unwrap_or_default();
271 if !snippet_text.is_empty() {
272 snippets.insert(decoded, snippet_text);
273 }
274 }
275 }
276 }
277 }
278
279 let mut results = Vec::new();
280 for capture in link_re.captures_iter(&html) {
281 let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
282 continue;
283 };
284 let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
285 continue;
286 };
287 let Some(host) = Self::host_of(&url) else {
288 continue;
289 };
290
291 if blocked
292 .iter()
293 .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
294 {
295 continue;
296 }
297 if let Some(allowed_set) = &allowed {
298 if !allowed_set
299 .iter()
300 .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
301 {
302 continue;
303 }
304 }
305
306 let title = capture
307 .get(2)
308 .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
309 .unwrap_or_else(|| url.clone());
310
311 let snippet = snippets.get(&url).cloned().unwrap_or_default();
312
313 let mut result = json!({
314 "title": title,
315 "url": url,
316 "domain": host,
317 });
318 if !snippet.is_empty() {
319 result["snippet"] = json!(snippet);
320 }
321 results.push(result);
322
323 if results.len() >= max_results {
324 break;
325 }
326 }
327
328 ctx.emit_tool_token(format!(
329 "Found {} results for \"{}\"\n",
330 results.len(),
331 query
332 ))
333 .await;
334
335 let result_value = if results.is_empty() {
336 json!({
337 "query": parsed.query,
338 "results": [],
339 "note": "No results found for this query.",
340 })
341 } else {
342 json!({
343 "query": parsed.query,
344 "results": results,
345 })
346 };
347
348 Self::put_cache(cache_key, result_value.clone());
350
351 let mut result_string = result_value.to_string();
352 result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
353
354 Ok(ToolResult {
355 success: true,
356 result: result_string,
357 display_preference: Some("Collapsible".to_string()),
358 })
359 }
360}
361
362#[cfg(test)]
363mod tests {
364 use super::*;
365
366 #[test]
367 fn domain_matches_supports_subdomains() {
368 assert!(WebSearchTool::domain_matches("example.com", "example.com"));
369 assert!(WebSearchTool::domain_matches(
370 "docs.example.com",
371 "example.com"
372 ));
373 assert!(!WebSearchTool::domain_matches(
374 "notexample.com",
375 "example.com"
376 ));
377 assert!(!WebSearchTool::domain_matches(
378 "evil-example.com",
379 "example.com"
380 ));
381 }
382
383 #[test]
384 fn host_of_normalizes_case() {
385 let host = WebSearchTool::host_of("https://Docs.Example.Com/path").unwrap();
386 assert_eq!(host, "docs.example.com");
387 }
388
389 #[test]
390 fn decode_duckduckgo_url_extracts_uddg_param() {
391 let raw = "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
392 let decoded = WebSearchTool::decode_duckduckgo_url(raw).unwrap();
393 assert_eq!(decoded, "https://example.com/page");
394 }
395
396 #[test]
397 fn cache_key_is_stable() {
398 let k1 =
399 WebSearchTool::cache_key("rust", &Some(vec!["doc.rust-lang.org".to_string()]), &None);
400 let k2 =
401 WebSearchTool::cache_key("rust", &Some(vec!["doc.rust-lang.org".to_string()]), &None);
402 assert_eq!(k1, k2);
403
404 let k3 = WebSearchTool::cache_key("rust", &None, &Some(vec!["bad.com".to_string()]));
405 assert_ne!(k1, k3);
406 }
407}