1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use parking_lot::RwLock;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8use std::sync::OnceLock;
9use std::time::{Duration, Instant};
10
/// How long a cached search response stays valid (15 minutes).
const CACHE_TTL: Duration = Duration::from_secs(15 * 60);
/// Number of results returned when the caller does not pass `max_results`.
const DEFAULT_MAX_RESULTS: usize = 10;
/// Upper bound applied to any caller-supplied `max_results`.
const ABSOLUTE_MAX_RESULTS: usize = 20;
/// Browser-like `User-Agent` header value sent with the search request.
const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
15
/// Arguments of a `WebSearch` call, deserialized from the tool's JSON payload.
#[derive(Debug, Deserialize)]
struct WebSearchArgs {
    /// Search query text; it is trimmed and length-checked before use.
    query: String,
    /// When present and non-empty, only results hosted on these domains
    /// (or their subdomains) are kept.
    #[serde(default)]
    allowed_domains: Option<Vec<String>>,
    /// When present and non-empty, results hosted on these domains
    /// (or their subdomains) are dropped.
    #[serde(default)]
    blocked_domains: Option<Vec<String>>,
    /// Requested result count; defaults to `DEFAULT_MAX_RESULTS` and is capped
    /// at `ABSOLUTE_MAX_RESULTS`.
    #[serde(default)]
    max_results: Option<usize>,
}
26
/// One entry of the in-memory search cache.
struct CachedSearch {
    /// The JSON payload that was produced for the search.
    results: serde_json::Value,
    /// Instant after which the entry is no longer served from the cache.
    expires_at: Instant,
}
31
32static SEARCH_CACHE: OnceLock<RwLock<HashMap<String, CachedSearch>>> = OnceLock::new();
33
34fn search_cache() -> &'static RwLock<HashMap<String, CachedSearch>> {
35 SEARCH_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
36}
37
// Lazily compiled regexes used while scraping the result page. Each one is
// initialised on first use through `get_or_init` in `execute_with_context`.

/// Matches a `result__a` anchor, capturing its `href` value and inner HTML.
static LINK_RE: OnceLock<Regex> = OnceLock::new();
/// Matches any HTML tag; used to strip markup from titles and snippets.
static TAG_RE: OnceLock<Regex> = OnceLock::new();
/// Matches a `result__snippet` anchor, capturing its inner HTML.
static SNIPPET_RE: OnceLock<Regex> = OnceLock::new();
/// Matches an `href="…"` attribute, capturing the attribute value.
static HREF_RE: OnceLock<Regex> = OnceLock::new();
44
/// Tool that queries the DuckDuckGo HTML endpoint and returns filtered results.
///
/// The type itself carries no state; cached responses live in a process-wide map.
pub struct WebSearchTool;
46
47impl WebSearchTool {
    /// Creates the tool. The type is a unit struct, so there is nothing to configure.
    pub fn new() -> Self {
        Self
    }
51
52 fn cache_key(
53 query: &str,
54 allowed: &Option<Vec<String>>,
55 blocked: &Option<Vec<String>>,
56 ) -> String {
57 let mut key = query.to_string();
58 if let Some(domains) = allowed {
59 key.push('|');
60 key.push_str(&domains.join(","));
61 }
62 key.push('|');
63 if let Some(domains) = blocked {
64 key.push_str(&domains.join(","));
65 }
66 key
67 }
68
69 fn try_cache(key: &str) -> Option<serde_json::Value> {
70 let cache = search_cache().read();
71 let entry = cache.get(key)?;
72 if entry.expires_at > Instant::now() {
73 Some(entry.results.clone())
74 } else {
75 None
76 }
77 }
78
79 fn put_cache(key: String, results: serde_json::Value) {
80 let mut cache = search_cache().write();
81 cache.insert(
82 key,
83 CachedSearch {
84 results,
85 expires_at: Instant::now() + CACHE_TTL,
86 },
87 );
88 }
89
90 fn decode_duckduckgo_url(raw: &str) -> Option<String> {
91 if let Ok(url) = url::Url::parse(raw) {
92 if let Some(value) = url
93 .query_pairs()
94 .find(|(key, _)| key == "uddg")
95 .map(|(_, value)| value.to_string())
96 {
97 return Some(value);
98 }
99 }
100
101 Some(raw.to_string())
102 }
103
104 fn host_of(url: &str) -> Option<String> {
105 url::Url::parse(url)
106 .ok()
107 .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
108 }
109
110 fn domain_matches(host: &str, domain: &str) -> bool {
111 host == domain || host.ends_with(&format!(".{}", domain))
112 }
113}
114
/// `Default` simply forwards to [`WebSearchTool::new`].
impl Default for WebSearchTool {
    fn default() -> Self {
        Self::new()
    }
}
120
121#[async_trait]
122impl Tool for WebSearchTool {
    /// Name under which the tool is exposed.
    fn name(&self) -> &str {
        "WebSearch"
    }
126
127 fn description(&self) -> &str {
128 "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
129 }
130
    /// Reports the tool as read-only.
    fn mutability(&self) -> crate::ToolMutability {
        crate::ToolMutability::ReadOnly
    }
134
    /// Reports that the tool may be run concurrently with other calls.
    fn concurrency_safe(&self) -> bool {
        true
    }
138
139 fn parameters_schema(&self) -> serde_json::Value {
140 json!({
141 "type": "object",
142 "properties": {
143 "query": {
144 "type": "string",
145 "minLength": 2,
146 "description": "The search query to use"
147 },
148 "allowed_domains": {
149 "type": "array",
150 "items": { "type": "string" },
151 "description": "Only include results from these domains"
152 },
153 "blocked_domains": {
154 "type": "array",
155 "items": { "type": "string" },
156 "description": "Never include results from these domains"
157 },
158 "max_results": {
159 "type": "number",
160 "description": "Maximum results to return (default 10, max 20)"
161 }
162 },
163 "required": ["query"],
164 "additionalProperties": false
165 })
166 }
167
168 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
169 self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
170 .await
171 }
172
173 async fn execute_with_context(
174 &self,
175 args: serde_json::Value,
176 ctx: ToolExecutionContext<'_>,
177 ) -> Result<ToolResult, ToolError> {
178 let parsed: WebSearchArgs = serde_json::from_value(args)
179 .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
180
181 let query = parsed.query.trim();
182 if query.len() < 2 {
183 return Err(ToolError::InvalidArguments(
184 "query must be at least 2 characters".to_string(),
185 ));
186 }
187
188 let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
189 let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
190
191 if allowed_domains.is_some() && blocked_domains.is_some() {
193 return Err(ToolError::InvalidArguments(
194 "Cannot specify both allowed_domains and blocked_domains in the same request"
195 .to_string(),
196 ));
197 }
198
199 let max_results = parsed
200 .max_results
201 .unwrap_or(DEFAULT_MAX_RESULTS)
202 .min(ABSOLUTE_MAX_RESULTS);
203
204 let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
206 if let Some(cached) = Self::try_cache(&cache_key) {
207 ctx.emit_tool_token("Using cached search results\n").await;
208 return Ok(ToolResult {
209 success: true,
210 result: cached.to_string(),
211 display_preference: Some("Collapsible".to_string()),
212 images: Vec::new(),
213 });
214 }
215
216 ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
217
218 let client = reqwest::Client::builder()
219 .timeout(Duration::from_secs(30))
220 .build()
221 .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
222
223 let response = client
224 .get("https://duckduckgo.com/html/")
225 .header("User-Agent", USER_AGENT)
226 .query(&[("q", query)])
227 .send()
228 .await
229 .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
230
231 let html = response.text().await.map_err(|e| {
232 ToolError::Execution(format!("Failed to decode web search body: {}", e))
233 })?;
234
235 if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
237 {
238 return Err(ToolError::Execution(
239 "Search blocked by anti-bot protection. Please retry.".to_string(),
240 ));
241 }
242
243 let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
244 domains
245 .into_iter()
246 .map(|value| value.to_ascii_lowercase())
247 .collect()
248 });
249 let blocked: HashSet<String> = blocked_domains
250 .unwrap_or_default()
251 .into_iter()
252 .map(|value| value.to_ascii_lowercase())
253 .collect();
254
255 let link_re = LINK_RE.get_or_init(|| {
256 Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
257 .expect("valid static regex")
258 });
259 let tag_re =
260 TAG_RE.get_or_init(|| Regex::new(r"(?is)<[^>]+>").expect("valid static regex"));
261 let snippet_re = SNIPPET_RE.get_or_init(|| {
262 Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
263 .expect("valid static regex")
264 });
265
266 let href_re =
268 HREF_RE.get_or_init(|| Regex::new(r#"href="([^"]+)""#).expect("valid static regex"));
269 let mut snippets: HashMap<String, String> = HashMap::new();
270 for cap in snippet_re.captures_iter(&html) {
271 if let Some(href_cap) = cap.get(0) {
272 let href_text = href_cap.as_str();
273 if let Some(url_match) = href_re.find(href_text) {
275 let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
276 if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
277 let snippet_text = cap
278 .get(1)
279 .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
280 .unwrap_or_default();
281 if !snippet_text.is_empty() {
282 snippets.insert(decoded, snippet_text);
283 }
284 }
285 }
286 }
287 }
288
289 let mut results = Vec::new();
290 for capture in link_re.captures_iter(&html) {
291 let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
292 continue;
293 };
294 let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
295 continue;
296 };
297 let Some(host) = Self::host_of(&url) else {
298 continue;
299 };
300
301 if blocked
302 .iter()
303 .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
304 {
305 continue;
306 }
307 if let Some(allowed_set) = &allowed {
308 if !allowed_set
309 .iter()
310 .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
311 {
312 continue;
313 }
314 }
315
316 let title = capture
317 .get(2)
318 .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
319 .unwrap_or_else(|| url.clone());
320
321 let snippet = snippets.get(&url).cloned().unwrap_or_default();
322
323 let mut result = json!({
324 "title": title,
325 "url": url,
326 "domain": host,
327 });
328 if !snippet.is_empty() {
329 result["snippet"] = json!(snippet);
330 }
331 results.push(result);
332
333 if results.len() >= max_results {
334 break;
335 }
336 }
337
338 ctx.emit_tool_token(format!(
339 "Found {} results for \"{}\"\n",
340 results.len(),
341 query
342 ))
343 .await;
344
345 let result_value = if results.is_empty() {
346 json!({
347 "query": parsed.query,
348 "results": [],
349 "note": "No results found for this query.",
350 })
351 } else {
352 json!({
353 "query": parsed.query,
354 "results": results,
355 })
356 };
357
358 Self::put_cache(cache_key, result_value.clone());
360
361 let mut result_string = result_value.to_string();
362 result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
363
364 Ok(ToolResult {
365 success: true,
366 result: result_string,
367 display_preference: Some("Collapsible".to_string()),
368 images: Vec::new(),
369 })
370 }
371}
372
// Unit tests for the pure helper functions; none of them touch the network.
#[cfg(test)]
mod tests {
    use super::*;

    // A domain matches itself and its subdomains, but not hosts that merely
    // end with the same characters.
    #[test]
    fn domain_matches_supports_subdomains() {
        assert!(WebSearchTool::domain_matches("example.com", "example.com"));
        assert!(WebSearchTool::domain_matches(
            "docs.example.com",
            "example.com"
        ));
        assert!(!WebSearchTool::domain_matches(
            "notexample.com",
            "example.com"
        ));
        assert!(!WebSearchTool::domain_matches(
            "evil-example.com",
            "example.com"
        ));
    }

    // The extracted host is lowercased.
    #[test]
    fn host_of_normalizes_case() {
        let host = WebSearchTool::host_of("https://Docs.Example.Com/path").unwrap();
        assert_eq!(host, "docs.example.com");
    }

    // A redirect link yields the percent-decoded value of its `uddg` parameter.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let raw = "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        let decoded = WebSearchTool::decode_duckduckgo_url(raw).unwrap();
        assert_eq!(decoded, "https://example.com/page");
    }

    // Identical inputs give identical keys; different filters give different keys.
    #[test]
    fn cache_key_is_stable() {
        let k1 =
            WebSearchTool::cache_key("rust", &Some(vec!["doc.rust-lang.org".to_string()]), &None);
        let k2 =
            WebSearchTool::cache_key("rust", &Some(vec!["doc.rust-lang.org".to_string()]), &None);
        assert_eq!(k1, k2);

        let k3 = WebSearchTool::cache_key("rust", &None, &Some(vec!["bad.com".to_string()]));
        assert_ne!(k1, k3);
    }
}