1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use parking_lot::RwLock;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8use std::sync::OnceLock;
9use std::time::{Duration, Instant};
10
/// User-Agent header value sent with every search request.
const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";

/// Number of results returned when the caller does not supply `max_results`.
const DEFAULT_MAX_RESULTS: usize = 10;
/// Upper bound applied to any caller-supplied `max_results`.
const ABSOLUTE_MAX_RESULTS: usize = 20;

/// How long a cached search payload stays valid (15 minutes).
const CACHE_TTL: Duration = Duration::from_secs(60 * 15);
15
16#[derive(Debug, Deserialize)]
17struct WebSearchArgs {
18 query: String,
19 #[serde(default)]
20 allowed_domains: Option<Vec<String>>,
21 #[serde(default)]
22 blocked_domains: Option<Vec<String>>,
23 #[serde(default)]
24 max_results: Option<usize>,
25}
26
27struct CachedSearch {
28 results: serde_json::Value,
29 expires_at: Instant,
30}
31
32static SEARCH_CACHE: OnceLock<RwLock<HashMap<String, CachedSearch>>> = OnceLock::new();
33
34fn search_cache() -> &'static RwLock<HashMap<String, CachedSearch>> {
35 SEARCH_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
36}
37
/// Tool that performs web searches against DuckDuckGo's HTML endpoint.
///
/// The type carries no state of its own, so it is cheap to copy; cached
/// results live in a process-wide map rather than on the instance.
#[derive(Debug, Clone, Copy)]
pub struct WebSearchTool;
39
40impl WebSearchTool {
41 pub fn new() -> Self {
42 Self
43 }
44
45 fn cache_key(
46 query: &str,
47 allowed: &Option<Vec<String>>,
48 blocked: &Option<Vec<String>>,
49 ) -> String {
50 let mut key = query.to_string();
51 if let Some(domains) = allowed {
52 key.push('|');
53 key.push_str(&domains.join(","));
54 }
55 key.push('|');
56 if let Some(domains) = blocked {
57 key.push_str(&domains.join(","));
58 }
59 key
60 }
61
62 fn try_cache(key: &str) -> Option<serde_json::Value> {
63 let cache = search_cache().read();
64 let entry = cache.get(key)?;
65 if entry.expires_at > Instant::now() {
66 Some(entry.results.clone())
67 } else {
68 None
69 }
70 }
71
72 fn put_cache(key: String, results: serde_json::Value) {
73 let mut cache = search_cache().write();
74 cache.insert(
75 key,
76 CachedSearch {
77 results,
78 expires_at: Instant::now() + CACHE_TTL,
79 },
80 );
81 }
82
83 fn decode_duckduckgo_url(raw: &str) -> Option<String> {
84 if let Ok(url) = url::Url::parse(raw) {
85 if let Some(value) = url
86 .query_pairs()
87 .find(|(key, _)| key == "uddg")
88 .map(|(_, value)| value.to_string())
89 {
90 return Some(value);
91 }
92 }
93
94 Some(raw.to_string())
95 }
96
97 fn host_of(url: &str) -> Option<String> {
98 url::Url::parse(url)
99 .ok()
100 .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
101 }
102
103 fn domain_matches(host: &str, domain: &str) -> bool {
104 host == domain || host.ends_with(&format!(".{}", domain))
105 }
106}
107
108impl Default for WebSearchTool {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[async_trait]
115impl Tool for WebSearchTool {
116 fn name(&self) -> &str {
117 "WebSearch"
118 }
119
120 fn description(&self) -> &str {
121 "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
122 }
123
124 fn mutability(&self) -> crate::ToolMutability {
125 crate::ToolMutability::ReadOnly
126 }
127
128 fn concurrency_safe(&self) -> bool {
129 true
130 }
131
132 fn parameters_schema(&self) -> serde_json::Value {
133 json!({
134 "type": "object",
135 "properties": {
136 "query": {
137 "type": "string",
138 "minLength": 2,
139 "description": "The search query to use"
140 },
141 "allowed_domains": {
142 "type": "array",
143 "items": { "type": "string" },
144 "description": "Only include results from these domains"
145 },
146 "blocked_domains": {
147 "type": "array",
148 "items": { "type": "string" },
149 "description": "Never include results from these domains"
150 },
151 "max_results": {
152 "type": "number",
153 "description": "Maximum results to return (default 10, max 20)"
154 }
155 },
156 "required": ["query"],
157 "additionalProperties": false
158 })
159 }
160
161 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
162 self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
163 .await
164 }
165
166 async fn execute_with_context(
167 &self,
168 args: serde_json::Value,
169 ctx: ToolExecutionContext<'_>,
170 ) -> Result<ToolResult, ToolError> {
171 let parsed: WebSearchArgs = serde_json::from_value(args)
172 .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
173
174 let query = parsed.query.trim();
175 if query.len() < 2 {
176 return Err(ToolError::InvalidArguments(
177 "query must be at least 2 characters".to_string(),
178 ));
179 }
180
181 let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
182 let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
183
184 if allowed_domains.is_some() && blocked_domains.is_some() {
186 return Err(ToolError::InvalidArguments(
187 "Cannot specify both allowed_domains and blocked_domains in the same request"
188 .to_string(),
189 ));
190 }
191
192 let max_results = parsed
193 .max_results
194 .unwrap_or(DEFAULT_MAX_RESULTS)
195 .min(ABSOLUTE_MAX_RESULTS);
196
197 let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
199 if let Some(cached) = Self::try_cache(&cache_key) {
200 ctx.emit_tool_token("Using cached search results\n").await;
201 return Ok(ToolResult {
202 success: true,
203 result: cached.to_string(),
204 display_preference: Some("Collapsible".to_string()),
205 });
206 }
207
208 ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
209
210 let client = reqwest::Client::builder()
211 .timeout(Duration::from_secs(30))
212 .build()
213 .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
214
215 let response = client
216 .get("https://duckduckgo.com/html/")
217 .header("User-Agent", USER_AGENT)
218 .query(&[("q", query)])
219 .send()
220 .await
221 .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
222
223 let html = response.text().await.map_err(|e| {
224 ToolError::Execution(format!("Failed to decode web search body: {}", e))
225 })?;
226
227 if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
229 {
230 return Err(ToolError::Execution(
231 "Search blocked by anti-bot protection. Please retry.".to_string(),
232 ));
233 }
234
235 let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
236 domains
237 .into_iter()
238 .map(|value| value.to_ascii_lowercase())
239 .collect()
240 });
241 let blocked: HashSet<String> = blocked_domains
242 .unwrap_or_default()
243 .into_iter()
244 .map(|value| value.to_ascii_lowercase())
245 .collect();
246
247 let link_re = Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
248 .map_err(|e| {
249 ToolError::Execution(format!("Failed to compile link regex: {}", e))
250 })?;
251 let tag_re = Regex::new(r"(?is)<[^>]+>")
252 .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
253 let snippet_re =
254 Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
255 .map_err(|e| {
256 ToolError::Execution(format!("Failed to compile snippet regex: {}", e))
257 })?;
258
259 let href_re = Regex::new(r#"href="([^"]+)""#)
261 .map_err(|e| ToolError::Execution(format!("Failed to compile href regex: {}", e)))?;
262 let mut snippets: HashMap<String, String> = HashMap::new();
263 for cap in snippet_re.captures_iter(&html) {
264 if let Some(href_cap) = cap.get(0) {
265 let href_text = href_cap.as_str();
266 if let Some(url_match) = href_re.find(href_text) {
268 let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
269 if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
270 let snippet_text = cap
271 .get(1)
272 .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
273 .unwrap_or_default();
274 if !snippet_text.is_empty() {
275 snippets.insert(decoded, snippet_text);
276 }
277 }
278 }
279 }
280 }
281
282 let mut results = Vec::new();
283 for capture in link_re.captures_iter(&html) {
284 let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
285 continue;
286 };
287 let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
288 continue;
289 };
290 let Some(host) = Self::host_of(&url) else {
291 continue;
292 };
293
294 if blocked
295 .iter()
296 .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
297 {
298 continue;
299 }
300 if let Some(allowed_set) = &allowed {
301 if !allowed_set
302 .iter()
303 .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
304 {
305 continue;
306 }
307 }
308
309 let title = capture
310 .get(2)
311 .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
312 .unwrap_or_else(|| url.clone());
313
314 let snippet = snippets.get(&url).cloned().unwrap_or_default();
315
316 let mut result = json!({
317 "title": title,
318 "url": url,
319 "domain": host,
320 });
321 if !snippet.is_empty() {
322 result["snippet"] = json!(snippet);
323 }
324 results.push(result);
325
326 if results.len() >= max_results {
327 break;
328 }
329 }
330
331 ctx.emit_tool_token(format!(
332 "Found {} results for \"{}\"\n",
333 results.len(),
334 query
335 ))
336 .await;
337
338 let result_value = if results.is_empty() {
339 json!({
340 "query": parsed.query,
341 "results": [],
342 "note": "No results found for this query.",
343 })
344 } else {
345 json!({
346 "query": parsed.query,
347 "results": results,
348 })
349 };
350
351 Self::put_cache(cache_key, result_value.clone());
353
354 let mut result_string = result_value.to_string();
355 result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
356
357 Ok(ToolResult {
358 success: true,
359 result: result_string,
360 display_preference: Some("Collapsible".to_string()),
361 })
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact hosts and subdomains match; look-alike hosts do not.
    #[test]
    fn domain_matches_supports_subdomains() {
        let cases = [
            ("example.com", true),
            ("docs.example.com", true),
            ("notexample.com", false),
            ("evil-example.com", false),
        ];
        for (host, expected) in cases {
            assert_eq!(
                WebSearchTool::domain_matches(host, "example.com"),
                expected,
                "unexpected match result for host {}",
                host
            );
        }
    }

    /// Hosts are returned in lowercase regardless of input casing.
    #[test]
    fn host_of_normalizes_case() {
        let host = WebSearchTool::host_of("https://Docs.Example.Com/path");
        assert_eq!(host.as_deref(), Some("docs.example.com"));
    }

    /// The `uddg` parameter of a redirect link is percent-decoded and returned.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let redirect =
            "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        let target = WebSearchTool::decode_duckduckgo_url(redirect);
        assert_eq!(target.as_deref(), Some("https://example.com/page"));
    }

    /// Equal inputs give equal keys; different filters give different keys.
    #[test]
    fn cache_key_is_stable() {
        let allow_docs = Some(vec!["doc.rust-lang.org".to_string()]);
        let first = WebSearchTool::cache_key("rust", &allow_docs, &None);
        let second = WebSearchTool::cache_key("rust", &allow_docs, &None);
        assert_eq!(first, second);

        let block_bad = Some(vec!["bad.com".to_string()]);
        let third = WebSearchTool::cache_key("rust", &None, &block_bad);
        assert_ne!(first, third);
    }
}