Skip to main content

bamboo_tools/tools/
web_search.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use parking_lot::RwLock;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8use std::sync::OnceLock;
9use std::time::{Duration, Instant};
10
11const CACHE_TTL: Duration = Duration::from_secs(15 * 60);
12const DEFAULT_MAX_RESULTS: usize = 10;
13const ABSOLUTE_MAX_RESULTS: usize = 20;
14const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
15
16#[derive(Debug, Deserialize)]
17struct WebSearchArgs {
18    query: String,
19    #[serde(default)]
20    allowed_domains: Option<Vec<String>>,
21    #[serde(default)]
22    blocked_domains: Option<Vec<String>>,
23    #[serde(default)]
24    max_results: Option<usize>,
25}
26
27struct CachedSearch {
28    results: serde_json::Value,
29    expires_at: Instant,
30}
31
32static SEARCH_CACHE: OnceLock<RwLock<HashMap<String, CachedSearch>>> = OnceLock::new();
33
34fn search_cache() -> &'static RwLock<HashMap<String, CachedSearch>> {
35    SEARCH_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
36}
37
38// Static, compile-time-constant patterns: compile each exactly once and reuse.
39// `expect` is safe here because the patterns are hardcoded and verified valid.
40static LINK_RE: OnceLock<Regex> = OnceLock::new();
41static TAG_RE: OnceLock<Regex> = OnceLock::new();
42static SNIPPET_RE: OnceLock<Regex> = OnceLock::new();
43static HREF_RE: OnceLock<Regex> = OnceLock::new();
44
/// DuckDuckGo-backed web search tool. The type itself carries no fields;
/// cached results live in a process-wide map shared by every instance.
pub struct WebSearchTool;
46
47impl WebSearchTool {
48    pub fn new() -> Self {
49        Self
50    }
51
52    fn cache_key(
53        query: &str,
54        allowed: &Option<Vec<String>>,
55        blocked: &Option<Vec<String>>,
56    ) -> String {
57        let mut key = query.to_string();
58        if let Some(domains) = allowed {
59            key.push('|');
60            key.push_str(&domains.join(","));
61        }
62        key.push('|');
63        if let Some(domains) = blocked {
64            key.push_str(&domains.join(","));
65        }
66        key
67    }
68
69    fn try_cache(key: &str) -> Option<serde_json::Value> {
70        let cache = search_cache().read();
71        let entry = cache.get(key)?;
72        if entry.expires_at > Instant::now() {
73            Some(entry.results.clone())
74        } else {
75            None
76        }
77    }
78
79    fn put_cache(key: String, results: serde_json::Value) {
80        let mut cache = search_cache().write();
81        cache.insert(
82            key,
83            CachedSearch {
84                results,
85                expires_at: Instant::now() + CACHE_TTL,
86            },
87        );
88    }
89
90    fn decode_duckduckgo_url(raw: &str) -> Option<String> {
91        if let Ok(url) = url::Url::parse(raw) {
92            if let Some(value) = url
93                .query_pairs()
94                .find(|(key, _)| key == "uddg")
95                .map(|(_, value)| value.to_string())
96            {
97                return Some(value);
98            }
99        }
100
101        Some(raw.to_string())
102    }
103
104    fn host_of(url: &str) -> Option<String> {
105        url::Url::parse(url)
106            .ok()
107            .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
108    }
109
110    fn domain_matches(host: &str, domain: &str) -> bool {
111        host == domain || host.ends_with(&format!(".{}", domain))
112    }
113}
114
115impl Default for WebSearchTool {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121#[async_trait]
122impl Tool for WebSearchTool {
123    fn name(&self) -> &str {
124        "WebSearch"
125    }
126
127    fn description(&self) -> &str {
128        "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
129    }
130
131    fn mutability(&self) -> crate::ToolMutability {
132        crate::ToolMutability::ReadOnly
133    }
134
135    fn concurrency_safe(&self) -> bool {
136        true
137    }
138
139    fn parameters_schema(&self) -> serde_json::Value {
140        json!({
141            "type": "object",
142            "properties": {
143                "query": {
144                    "type": "string",
145                    "minLength": 2,
146                    "description": "The search query to use"
147                },
148                "allowed_domains": {
149                    "type": "array",
150                    "items": { "type": "string" },
151                    "description": "Only include results from these domains"
152                },
153                "blocked_domains": {
154                    "type": "array",
155                    "items": { "type": "string" },
156                    "description": "Never include results from these domains"
157                },
158                "max_results": {
159                    "type": "number",
160                    "description": "Maximum results to return (default 10, max 20)"
161                }
162            },
163            "required": ["query"],
164            "additionalProperties": false
165        })
166    }
167
168    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
169        self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
170            .await
171    }
172
173    async fn execute_with_context(
174        &self,
175        args: serde_json::Value,
176        ctx: ToolExecutionContext<'_>,
177    ) -> Result<ToolResult, ToolError> {
178        let parsed: WebSearchArgs = serde_json::from_value(args)
179            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
180
181        let query = parsed.query.trim();
182        if query.len() < 2 {
183            return Err(ToolError::InvalidArguments(
184                "query must be at least 2 characters".to_string(),
185            ));
186        }
187
188        let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
189        let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
190
191        // Mutual-exclusion validation
192        if allowed_domains.is_some() && blocked_domains.is_some() {
193            return Err(ToolError::InvalidArguments(
194                "Cannot specify both allowed_domains and blocked_domains in the same request"
195                    .to_string(),
196            ));
197        }
198
199        let max_results = parsed
200            .max_results
201            .unwrap_or(DEFAULT_MAX_RESULTS)
202            .min(ABSOLUTE_MAX_RESULTS);
203
204        // Check cache
205        let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
206        if let Some(cached) = Self::try_cache(&cache_key) {
207            ctx.emit_tool_token("Using cached search results\n").await;
208            return Ok(ToolResult {
209                success: true,
210                result: cached.to_string(),
211                display_preference: Some("Collapsible".to_string()),
212                images: Vec::new(),
213            });
214        }
215
216        ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
217
218        let client = reqwest::Client::builder()
219            .timeout(Duration::from_secs(30))
220            .build()
221            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
222
223        let response = client
224            .get("https://duckduckgo.com/html/")
225            .header("User-Agent", USER_AGENT)
226            .query(&[("q", query)])
227            .send()
228            .await
229            .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
230
231        let html = response.text().await.map_err(|e| {
232            ToolError::Execution(format!("Failed to decode web search body: {}", e))
233        })?;
234
235        // Detect anti-bot page
236        if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
237        {
238            return Err(ToolError::Execution(
239                "Search blocked by anti-bot protection. Please retry.".to_string(),
240            ));
241        }
242
243        let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
244            domains
245                .into_iter()
246                .map(|value| value.to_ascii_lowercase())
247                .collect()
248        });
249        let blocked: HashSet<String> = blocked_domains
250            .unwrap_or_default()
251            .into_iter()
252            .map(|value| value.to_ascii_lowercase())
253            .collect();
254
255        let link_re = LINK_RE.get_or_init(|| {
256            Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
257                .expect("valid static regex")
258        });
259        let tag_re =
260            TAG_RE.get_or_init(|| Regex::new(r"(?is)<[^>]+>").expect("valid static regex"));
261        let snippet_re = SNIPPET_RE.get_or_init(|| {
262            Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
263                .expect("valid static regex")
264        });
265
266        // Build a map of snippet content by href (to match with result links)
267        let href_re =
268            HREF_RE.get_or_init(|| Regex::new(r#"href="([^"]+)""#).expect("valid static regex"));
269        let mut snippets: HashMap<String, String> = HashMap::new();
270        for cap in snippet_re.captures_iter(&html) {
271            if let Some(href_cap) = cap.get(0) {
272                let href_text = href_cap.as_str();
273                // Extract the href URL from the snippet anchor
274                if let Some(url_match) = href_re.find(href_text) {
275                    let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
276                    if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
277                        let snippet_text = cap
278                            .get(1)
279                            .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
280                            .unwrap_or_default();
281                        if !snippet_text.is_empty() {
282                            snippets.insert(decoded, snippet_text);
283                        }
284                    }
285                }
286            }
287        }
288
289        let mut results = Vec::new();
290        for capture in link_re.captures_iter(&html) {
291            let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
292                continue;
293            };
294            let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
295                continue;
296            };
297            let Some(host) = Self::host_of(&url) else {
298                continue;
299            };
300
301            if blocked
302                .iter()
303                .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
304            {
305                continue;
306            }
307            if let Some(allowed_set) = &allowed {
308                if !allowed_set
309                    .iter()
310                    .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
311                {
312                    continue;
313                }
314            }
315
316            let title = capture
317                .get(2)
318                .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
319                .unwrap_or_else(|| url.clone());
320
321            let snippet = snippets.get(&url).cloned().unwrap_or_default();
322
323            let mut result = json!({
324                "title": title,
325                "url": url,
326                "domain": host,
327            });
328            if !snippet.is_empty() {
329                result["snippet"] = json!(snippet);
330            }
331            results.push(result);
332
333            if results.len() >= max_results {
334                break;
335            }
336        }
337
338        ctx.emit_tool_token(format!(
339            "Found {} results for \"{}\"\n",
340            results.len(),
341            query
342        ))
343        .await;
344
345        let result_value = if results.is_empty() {
346            json!({
347                "query": parsed.query,
348                "results": [],
349                "note": "No results found for this query.",
350            })
351        } else {
352            json!({
353                "query": parsed.query,
354                "results": results,
355            })
356        };
357
358        // Store in cache
359        Self::put_cache(cache_key, result_value.clone());
360
361        let mut result_string = result_value.to_string();
362        result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
363
364        Ok(ToolResult {
365            success: true,
366            result: result_string,
367            display_preference: Some("Collapsible".to_string()),
368            images: Vec::new(),
369        })
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact hosts and true subdomains match; look-alike hosts do not.
    #[test]
    fn domain_matches_supports_subdomains() {
        let cases = [
            ("example.com", true),
            ("docs.example.com", true),
            ("notexample.com", false),
            ("evil-example.com", false),
        ];
        for (host, expected) in cases {
            assert_eq!(
                WebSearchTool::domain_matches(host, "example.com"),
                expected,
                "host: {}",
                host
            );
        }
    }

    /// Mixed-case hosts come back lower-cased.
    #[test]
    fn host_of_normalizes_case() {
        let host = WebSearchTool::host_of("https://Docs.Example.Com/path");
        assert_eq!(host.as_deref(), Some("docs.example.com"));
    }

    /// A redirect link yields the percent-decoded `uddg` target.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let redirect =
            "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        let target = WebSearchTool::decode_duckduckgo_url(redirect);
        assert_eq!(target.as_deref(), Some("https://example.com/page"));
    }

    /// Identical inputs give identical keys; different filters give different keys.
    #[test]
    fn cache_key_is_stable() {
        let allow_only = Some(vec!["doc.rust-lang.org".to_string()]);
        let block_only = Some(vec!["bad.com".to_string()]);

        let first = WebSearchTool::cache_key("rust", &allow_only, &None);
        let second = WebSearchTool::cache_key("rust", &allow_only, &None);
        assert_eq!(first, second);

        let with_blocked = WebSearchTool::cache_key("rust", &None, &block_only);
        assert_ne!(first, with_blocked);
    }
}