Skip to main content

bamboo_tools/tools/
web_search.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use parking_lot::RwLock;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8use std::sync::OnceLock;
9use std::time::{Duration, Instant};
10
11const CACHE_TTL: Duration = Duration::from_secs(15 * 60);
12const DEFAULT_MAX_RESULTS: usize = 10;
13const ABSOLUTE_MAX_RESULTS: usize = 20;
14const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
15
/// Arguments accepted by the `WebSearch` tool, deserialized from the JSON
/// value passed to `execute_with_context`.
#[derive(Debug, Deserialize)]
struct WebSearchArgs {
    /// Search query text. It is trimmed and length-checked by
    /// `execute_with_context` before any request is made.
    query: String,
    /// Optional allow-list: when non-empty, only results whose host matches
    /// one of these domains are kept. Rejected by `execute_with_context` when
    /// combined with a non-empty `blocked_domains`.
    #[serde(default)]
    allowed_domains: Option<Vec<String>>,
    /// Optional block-list: results whose host matches one of these domains
    /// are dropped. Rejected by `execute_with_context` when combined with a
    /// non-empty `allowed_domains`.
    #[serde(default)]
    blocked_domains: Option<Vec<String>>,
    /// Requested result count. When absent `DEFAULT_MAX_RESULTS` is used;
    /// values above `ABSOLUTE_MAX_RESULTS` are capped to it.
    #[serde(default)]
    max_results: Option<usize>,
}
26
27struct CachedSearch {
28    results: serde_json::Value,
29    expires_at: Instant,
30}
31
32static SEARCH_CACHE: OnceLock<RwLock<HashMap<String, CachedSearch>>> = OnceLock::new();
33
34fn search_cache() -> &'static RwLock<HashMap<String, CachedSearch>> {
35    SEARCH_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
36}
37
/// Tool that searches DuckDuckGo and returns filtered results.
///
/// The type carries no fields; cached responses live in the process-wide
/// `SEARCH_CACHE`, so every instance shares the same cache.
#[derive(Debug, Clone, Copy)]
pub struct WebSearchTool;
39
40impl WebSearchTool {
41    pub fn new() -> Self {
42        Self
43    }
44
45    fn cache_key(
46        query: &str,
47        allowed: &Option<Vec<String>>,
48        blocked: &Option<Vec<String>>,
49    ) -> String {
50        let mut key = query.to_string();
51        if let Some(domains) = allowed {
52            key.push('|');
53            key.push_str(&domains.join(","));
54        }
55        key.push('|');
56        if let Some(domains) = blocked {
57            key.push_str(&domains.join(","));
58        }
59        key
60    }
61
62    fn try_cache(key: &str) -> Option<serde_json::Value> {
63        let cache = search_cache().read();
64        let entry = cache.get(key)?;
65        if entry.expires_at > Instant::now() {
66            Some(entry.results.clone())
67        } else {
68            None
69        }
70    }
71
72    fn put_cache(key: String, results: serde_json::Value) {
73        let mut cache = search_cache().write();
74        cache.insert(
75            key,
76            CachedSearch {
77                results,
78                expires_at: Instant::now() + CACHE_TTL,
79            },
80        );
81    }
82
83    fn decode_duckduckgo_url(raw: &str) -> Option<String> {
84        if let Ok(url) = url::Url::parse(raw) {
85            if let Some(value) = url
86                .query_pairs()
87                .find(|(key, _)| key == "uddg")
88                .map(|(_, value)| value.to_string())
89            {
90                return Some(value);
91            }
92        }
93
94        Some(raw.to_string())
95    }
96
97    fn host_of(url: &str) -> Option<String> {
98        url::Url::parse(url)
99            .ok()
100            .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
101    }
102
103    fn domain_matches(host: &str, domain: &str) -> bool {
104        host == domain || host.ends_with(&format!(".{}", domain))
105    }
106}
107
108impl Default for WebSearchTool {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[async_trait]
115impl Tool for WebSearchTool {
    /// Returns the tool's name, `"WebSearch"` (the same string `execute`
    /// passes to `ToolExecutionContext::none`).
    fn name(&self) -> &str {
        "WebSearch"
    }
119
120    fn description(&self) -> &str {
121        "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
122    }
123
    /// Declares the tool as `ToolMutability::ReadOnly`.
    fn mutability(&self) -> crate::ToolMutability {
        crate::ToolMutability::ReadOnly
    }
127
    /// Always returns `true`.
    // NOTE(review): presumably this tells the host that concurrent
    // invocations are allowed; confirm against the `Tool` trait docs.
    fn concurrency_safe(&self) -> bool {
        true
    }
131
132    fn parameters_schema(&self) -> serde_json::Value {
133        json!({
134            "type": "object",
135            "properties": {
136                "query": {
137                    "type": "string",
138                    "minLength": 2,
139                    "description": "The search query to use"
140                },
141                "allowed_domains": {
142                    "type": "array",
143                    "items": { "type": "string" },
144                    "description": "Only include results from these domains"
145                },
146                "blocked_domains": {
147                    "type": "array",
148                    "items": { "type": "string" },
149                    "description": "Never include results from these domains"
150                },
151                "max_results": {
152                    "type": "number",
153                    "description": "Maximum results to return (default 10, max 20)"
154                }
155            },
156            "required": ["query"],
157            "additionalProperties": false
158        })
159    }
160
161    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
162        self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
163            .await
164    }
165
166    async fn execute_with_context(
167        &self,
168        args: serde_json::Value,
169        ctx: ToolExecutionContext<'_>,
170    ) -> Result<ToolResult, ToolError> {
171        let parsed: WebSearchArgs = serde_json::from_value(args)
172            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
173
174        let query = parsed.query.trim();
175        if query.len() < 2 {
176            return Err(ToolError::InvalidArguments(
177                "query must be at least 2 characters".to_string(),
178            ));
179        }
180
181        let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
182        let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
183
184        // Mutual-exclusion validation
185        if allowed_domains.is_some() && blocked_domains.is_some() {
186            return Err(ToolError::InvalidArguments(
187                "Cannot specify both allowed_domains and blocked_domains in the same request"
188                    .to_string(),
189            ));
190        }
191
192        let max_results = parsed
193            .max_results
194            .unwrap_or(DEFAULT_MAX_RESULTS)
195            .min(ABSOLUTE_MAX_RESULTS);
196
197        // Check cache
198        let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
199        if let Some(cached) = Self::try_cache(&cache_key) {
200            ctx.emit_tool_token("Using cached search results\n").await;
201            return Ok(ToolResult {
202                success: true,
203                result: cached.to_string(),
204                display_preference: Some("Collapsible".to_string()),
205            });
206        }
207
208        ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
209
210        let client = reqwest::Client::builder()
211            .timeout(Duration::from_secs(30))
212            .build()
213            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
214
215        let response = client
216            .get("https://duckduckgo.com/html/")
217            .header("User-Agent", USER_AGENT)
218            .query(&[("q", query)])
219            .send()
220            .await
221            .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
222
223        let html = response.text().await.map_err(|e| {
224            ToolError::Execution(format!("Failed to decode web search body: {}", e))
225        })?;
226
227        // Detect anti-bot page
228        if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
229        {
230            return Err(ToolError::Execution(
231                "Search blocked by anti-bot protection. Please retry.".to_string(),
232            ));
233        }
234
235        let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
236            domains
237                .into_iter()
238                .map(|value| value.to_ascii_lowercase())
239                .collect()
240        });
241        let blocked: HashSet<String> = blocked_domains
242            .unwrap_or_default()
243            .into_iter()
244            .map(|value| value.to_ascii_lowercase())
245            .collect();
246
247        let link_re = Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
248            .map_err(|e| {
249            ToolError::Execution(format!("Failed to compile link regex: {}", e))
250        })?;
251        let tag_re = Regex::new(r"(?is)<[^>]+>")
252            .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
253        let snippet_re =
254            Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
255                .map_err(|e| {
256                    ToolError::Execution(format!("Failed to compile snippet regex: {}", e))
257                })?;
258
259        // Build a map of snippet content by href (to match with result links)
260        let href_re = Regex::new(r#"href="([^"]+)""#)
261            .map_err(|e| ToolError::Execution(format!("Failed to compile href regex: {}", e)))?;
262        let mut snippets: HashMap<String, String> = HashMap::new();
263        for cap in snippet_re.captures_iter(&html) {
264            if let Some(href_cap) = cap.get(0) {
265                let href_text = href_cap.as_str();
266                // Extract the href URL from the snippet anchor
267                if let Some(url_match) = href_re.find(href_text) {
268                    let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
269                    if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
270                        let snippet_text = cap
271                            .get(1)
272                            .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
273                            .unwrap_or_default();
274                        if !snippet_text.is_empty() {
275                            snippets.insert(decoded, snippet_text);
276                        }
277                    }
278                }
279            }
280        }
281
282        let mut results = Vec::new();
283        for capture in link_re.captures_iter(&html) {
284            let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
285                continue;
286            };
287            let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
288                continue;
289            };
290            let Some(host) = Self::host_of(&url) else {
291                continue;
292            };
293
294            if blocked
295                .iter()
296                .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
297            {
298                continue;
299            }
300            if let Some(allowed_set) = &allowed {
301                if !allowed_set
302                    .iter()
303                    .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
304                {
305                    continue;
306                }
307            }
308
309            let title = capture
310                .get(2)
311                .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
312                .unwrap_or_else(|| url.clone());
313
314            let snippet = snippets.get(&url).cloned().unwrap_or_default();
315
316            let mut result = json!({
317                "title": title,
318                "url": url,
319                "domain": host,
320            });
321            if !snippet.is_empty() {
322                result["snippet"] = json!(snippet);
323            }
324            results.push(result);
325
326            if results.len() >= max_results {
327                break;
328            }
329        }
330
331        ctx.emit_tool_token(format!(
332            "Found {} results for \"{}\"\n",
333            results.len(),
334            query
335        ))
336        .await;
337
338        let result_value = if results.is_empty() {
339            json!({
340                "query": parsed.query,
341                "results": [],
342                "note": "No results found for this query.",
343            })
344        } else {
345            json!({
346                "query": parsed.query,
347                "results": results,
348            })
349        };
350
351        // Store in cache
352        Self::put_cache(cache_key, result_value.clone());
353
354        let mut result_string = result_value.to_string();
355        result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
356
357        Ok(ToolResult {
358            success: true,
359            result: result_string,
360            display_preference: Some("Collapsible".to_string()),
361        })
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// A host matches a domain only when it equals the domain or is a
    /// dot-separated subdomain of it.
    #[test]
    fn domain_matches_supports_subdomains() {
        let cases = [
            ("example.com", true),
            ("docs.example.com", true),
            ("notexample.com", false),
            ("evil-example.com", false),
        ];
        for (host, expected) in cases {
            assert_eq!(
                WebSearchTool::domain_matches(host, "example.com"),
                expected,
                "unexpected match result for host {host}"
            );
        }
    }

    /// Hosts are returned in lower case regardless of the input's casing.
    #[test]
    fn host_of_normalizes_case() {
        let host = WebSearchTool::host_of("https://Docs.Example.Com/path");
        assert_eq!(host.as_deref(), Some("docs.example.com"));
    }

    /// The `uddg` redirect parameter is extracted and percent-decoded.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let redirect =
            "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        let target = WebSearchTool::decode_duckduckgo_url(redirect);
        assert_eq!(target.as_deref(), Some("https://example.com/page"));
    }

    /// Identical requests share a key; allow-list and block-list requests do not.
    #[test]
    fn cache_key_is_stable() {
        let allow_list = Some(vec!["doc.rust-lang.org".to_string()]);
        let block_list = Some(vec!["bad.com".to_string()]);

        let first = WebSearchTool::cache_key("rust", &allow_list, &None);
        let second = WebSearchTool::cache_key("rust", &allow_list, &None);
        assert_eq!(first, second);

        let blocked_variant = WebSearchTool::cache_key("rust", &None, &block_list);
        assert_ne!(first, blocked_variant);
    }
}