Skip to main content

bamboo_tools/tools/
web_search.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use parking_lot::RwLock;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8use std::sync::OnceLock;
9use std::time::{Duration, Instant};
10
/// How long a stored search response is considered fresh (15 minutes).
const CACHE_TTL: Duration = Duration::from_secs(15 * 60);
/// Result count used when the caller does not pass `max_results`.
const DEFAULT_MAX_RESULTS: usize = 10;
/// Upper bound applied to any caller-supplied `max_results`.
const ABSOLUTE_MAX_RESULTS: usize = 20;
/// Browser-like `User-Agent` header value sent with the search request.
const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
15
/// Arguments accepted by the `WebSearch` tool, deserialized from the JSON
/// value passed to `execute` / `execute_with_context`.
#[derive(Debug, Deserialize)]
struct WebSearchArgs {
    /// Search text; it is trimmed before use and must be at least 2 long.
    query: String,
    /// When present and non-empty, only results whose host matches one of
    /// these domains (or a subdomain of one) are kept.
    #[serde(default)]
    allowed_domains: Option<Vec<String>>,
    /// When present and non-empty, results whose host matches one of these
    /// domains (or a subdomain of one) are dropped.
    #[serde(default)]
    blocked_domains: Option<Vec<String>>,
    /// Requested result count; falls back to `DEFAULT_MAX_RESULTS` and is
    /// capped at `ABSOLUTE_MAX_RESULTS`.
    #[serde(default)]
    max_results: Option<usize>,
}
26
/// One entry of the in-memory search cache.
struct CachedSearch {
    /// The JSON document that was returned for the search.
    results: serde_json::Value,
    /// Instant after which `try_cache` no longer returns this entry.
    expires_at: Instant,
}
31
/// Process-wide cache of search responses keyed by the string built in
/// `WebSearchTool::cache_key`; initialized lazily through `search_cache()`.
static SEARCH_CACHE: OnceLock<RwLock<HashMap<String, CachedSearch>>> = OnceLock::new();
33
34fn search_cache() -> &'static RwLock<HashMap<String, CachedSearch>> {
35    SEARCH_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
36}
37
/// Stateless tool that searches DuckDuckGo's HTML endpoint and returns the
/// parsed results as JSON text. Cached responses live in a process-wide map,
/// not in the struct itself.
pub struct WebSearchTool;
39
40impl WebSearchTool {
41    pub fn new() -> Self {
42        Self
43    }
44
45    fn cache_key(
46        query: &str,
47        allowed: &Option<Vec<String>>,
48        blocked: &Option<Vec<String>>,
49    ) -> String {
50        let mut key = query.to_string();
51        if let Some(domains) = allowed {
52            key.push('|');
53            key.push_str(&domains.join(","));
54        }
55        key.push('|');
56        if let Some(domains) = blocked {
57            key.push_str(&domains.join(","));
58        }
59        key
60    }
61
62    fn try_cache(key: &str) -> Option<serde_json::Value> {
63        let cache = search_cache().read();
64        let entry = cache.get(key)?;
65        if entry.expires_at > Instant::now() {
66            Some(entry.results.clone())
67        } else {
68            None
69        }
70    }
71
72    fn put_cache(key: String, results: serde_json::Value) {
73        let mut cache = search_cache().write();
74        cache.insert(
75            key,
76            CachedSearch {
77                results,
78                expires_at: Instant::now() + CACHE_TTL,
79            },
80        );
81    }
82
83    fn decode_duckduckgo_url(raw: &str) -> Option<String> {
84        if let Ok(url) = url::Url::parse(raw) {
85            if let Some(value) = url
86                .query_pairs()
87                .find(|(key, _)| key == "uddg")
88                .map(|(_, value)| value.to_string())
89            {
90                return Some(value);
91            }
92        }
93
94        Some(raw.to_string())
95    }
96
97    fn host_of(url: &str) -> Option<String> {
98        url::Url::parse(url)
99            .ok()
100            .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
101    }
102
103    fn domain_matches(host: &str, domain: &str) -> bool {
104        host == domain || host.ends_with(&format!(".{}", domain))
105    }
106}
107
108impl Default for WebSearchTool {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[async_trait]
115impl Tool for WebSearchTool {
    /// Tool identifier; the same string is passed to
    /// `ToolExecutionContext::none` in `execute`.
    fn name(&self) -> &str {
        "WebSearch"
    }
119
120    fn description(&self) -> &str {
121        "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
122    }
123
    /// Declares the tool as read-only.
    // NOTE(review): how the runtime uses `ToolMutability` is not visible in
    // this file — confirm against the `Tool` trait documentation.
    fn mutability(&self) -> crate::ToolMutability {
        crate::ToolMutability::ReadOnly
    }
127
    /// Always reports `true`.
    // NOTE(review): presumably lets the runtime run several calls in
    // parallel; the only shared state in this file is the `RwLock`-guarded
    // cache. Confirm the exact contract against the `Tool` trait.
    fn concurrency_safe(&self) -> bool {
        true
    }
131
132    fn parameters_schema(&self) -> serde_json::Value {
133        json!({
134            "type": "object",
135            "properties": {
136                "query": {
137                    "type": "string",
138                    "minLength": 2,
139                    "description": "The search query to use"
140                },
141                "allowed_domains": {
142                    "type": "array",
143                    "items": { "type": "string" },
144                    "description": "Only include results from these domains"
145                },
146                "blocked_domains": {
147                    "type": "array",
148                    "items": { "type": "string" },
149                    "description": "Never include results from these domains"
150                },
151                "max_results": {
152                    "type": "number",
153                    "description": "Maximum results to return (default 10, max 20)"
154                }
155            },
156            "required": ["query"],
157            "additionalProperties": false
158        })
159    }
160
161    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
162        self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
163            .await
164    }
165
166    async fn execute_with_context(
167        &self,
168        args: serde_json::Value,
169        ctx: ToolExecutionContext<'_>,
170    ) -> Result<ToolResult, ToolError> {
171        let parsed: WebSearchArgs = serde_json::from_value(args)
172            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
173
174        let query = parsed.query.trim();
175        if query.len() < 2 {
176            return Err(ToolError::InvalidArguments(
177                "query must be at least 2 characters".to_string(),
178            ));
179        }
180
181        let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
182        let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
183
184        // Mutual-exclusion validation
185        if allowed_domains.is_some() && blocked_domains.is_some() {
186            return Err(ToolError::InvalidArguments(
187                "Cannot specify both allowed_domains and blocked_domains in the same request"
188                    .to_string(),
189            ));
190        }
191
192        let max_results = parsed
193            .max_results
194            .unwrap_or(DEFAULT_MAX_RESULTS)
195            .min(ABSOLUTE_MAX_RESULTS);
196
197        // Check cache
198        let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
199        if let Some(cached) = Self::try_cache(&cache_key) {
200            ctx.emit_tool_token("Using cached search results\n").await;
201            return Ok(ToolResult {
202                success: true,
203                result: cached.to_string(),
204                display_preference: Some("Collapsible".to_string()),
205                images: Vec::new(),
206            });
207        }
208
209        ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
210
211        let client = reqwest::Client::builder()
212            .timeout(Duration::from_secs(30))
213            .build()
214            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
215
216        let response = client
217            .get("https://duckduckgo.com/html/")
218            .header("User-Agent", USER_AGENT)
219            .query(&[("q", query)])
220            .send()
221            .await
222            .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
223
224        let html = response.text().await.map_err(|e| {
225            ToolError::Execution(format!("Failed to decode web search body: {}", e))
226        })?;
227
228        // Detect anti-bot page
229        if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
230        {
231            return Err(ToolError::Execution(
232                "Search blocked by anti-bot protection. Please retry.".to_string(),
233            ));
234        }
235
236        let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
237            domains
238                .into_iter()
239                .map(|value| value.to_ascii_lowercase())
240                .collect()
241        });
242        let blocked: HashSet<String> = blocked_domains
243            .unwrap_or_default()
244            .into_iter()
245            .map(|value| value.to_ascii_lowercase())
246            .collect();
247
248        let link_re = Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
249            .map_err(|e| {
250            ToolError::Execution(format!("Failed to compile link regex: {}", e))
251        })?;
252        let tag_re = Regex::new(r"(?is)<[^>]+>")
253            .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
254        let snippet_re =
255            Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
256                .map_err(|e| {
257                    ToolError::Execution(format!("Failed to compile snippet regex: {}", e))
258                })?;
259
260        // Build a map of snippet content by href (to match with result links)
261        let href_re = Regex::new(r#"href="([^"]+)""#)
262            .map_err(|e| ToolError::Execution(format!("Failed to compile href regex: {}", e)))?;
263        let mut snippets: HashMap<String, String> = HashMap::new();
264        for cap in snippet_re.captures_iter(&html) {
265            if let Some(href_cap) = cap.get(0) {
266                let href_text = href_cap.as_str();
267                // Extract the href URL from the snippet anchor
268                if let Some(url_match) = href_re.find(href_text) {
269                    let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
270                    if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
271                        let snippet_text = cap
272                            .get(1)
273                            .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
274                            .unwrap_or_default();
275                        if !snippet_text.is_empty() {
276                            snippets.insert(decoded, snippet_text);
277                        }
278                    }
279                }
280            }
281        }
282
283        let mut results = Vec::new();
284        for capture in link_re.captures_iter(&html) {
285            let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
286                continue;
287            };
288            let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
289                continue;
290            };
291            let Some(host) = Self::host_of(&url) else {
292                continue;
293            };
294
295            if blocked
296                .iter()
297                .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
298            {
299                continue;
300            }
301            if let Some(allowed_set) = &allowed {
302                if !allowed_set
303                    .iter()
304                    .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
305                {
306                    continue;
307                }
308            }
309
310            let title = capture
311                .get(2)
312                .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
313                .unwrap_or_else(|| url.clone());
314
315            let snippet = snippets.get(&url).cloned().unwrap_or_default();
316
317            let mut result = json!({
318                "title": title,
319                "url": url,
320                "domain": host,
321            });
322            if !snippet.is_empty() {
323                result["snippet"] = json!(snippet);
324            }
325            results.push(result);
326
327            if results.len() >= max_results {
328                break;
329            }
330        }
331
332        ctx.emit_tool_token(format!(
333            "Found {} results for \"{}\"\n",
334            results.len(),
335            query
336        ))
337        .await;
338
339        let result_value = if results.is_empty() {
340            json!({
341                "query": parsed.query,
342                "results": [],
343                "note": "No results found for this query.",
344            })
345        } else {
346            json!({
347                "query": parsed.query,
348                "results": results,
349            })
350        };
351
352        // Store in cache
353        Self::put_cache(cache_key, result_value.clone());
354
355        let mut result_string = result_value.to_string();
356        result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
357
358        Ok(ToolResult {
359            success: true,
360            result: result_string,
361            display_preference: Some("Collapsible".to_string()),
362            images: Vec::new(),
363        })
364    }
365}
366
/// Unit tests for the pure helpers; none of them touch the network.
#[cfg(test)]
mod tests {
    use super::*;

    // Exact hosts and subdomains match; look-alike hosts that merely end in
    // the same characters do not.
    #[test]
    fn domain_matches_supports_subdomains() {
        assert!(WebSearchTool::domain_matches("example.com", "example.com"));
        assert!(WebSearchTool::domain_matches(
            "docs.example.com",
            "example.com"
        ));
        assert!(!WebSearchTool::domain_matches(
            "notexample.com",
            "example.com"
        ));
        assert!(!WebSearchTool::domain_matches(
            "evil-example.com",
            "example.com"
        ));
    }

    // The extracted host is returned in lowercase.
    #[test]
    fn host_of_normalizes_case() {
        let host = WebSearchTool::host_of("https://Docs.Example.Com/path").unwrap();
        assert_eq!(host, "docs.example.com");
    }

    // A redirect link yields the percent-decoded value of its `uddg` parameter.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let raw = "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        let decoded = WebSearchTool::decode_duckduckgo_url(raw).unwrap();
        assert_eq!(decoded, "https://example.com/page");
    }

    // Same inputs give the same key; an allow-list and a block-list give
    // different keys.
    #[test]
    fn cache_key_is_stable() {
        let k1 =
            WebSearchTool::cache_key("rust", &Some(vec!["doc.rust-lang.org".to_string()]), &None);
        let k2 =
            WebSearchTool::cache_key("rust", &Some(vec!["doc.rust-lang.org".to_string()]), &None);
        assert_eq!(k1, k2);

        let k3 = WebSearchTool::cache_key("rust", &None, &Some(vec!["bad.com".to_string()]));
        assert_ne!(k1, k3);
    }
}
412}