//! WebSearch tool: scrapes DuckDuckGo's server-rendered HTML endpoint,
//! applies optional allow/block domain filters, and caches results in-process.
//! (bamboo_tools/tools/web_search.rs)
1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use parking_lot::RwLock;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8use std::sync::LazyLock;
9use std::time::{Duration, Instant};
10
/// How long a cached search result stays fresh (15 minutes).
const CACHE_TTL: Duration = Duration::from_secs(15 * 60);
/// Number of results returned when the caller does not specify `max_results`.
const DEFAULT_MAX_RESULTS: usize = 10;
/// Hard ceiling on results regardless of the requested `max_results`.
const ABSOLUTE_MAX_RESULTS: usize = 20;
/// Desktop-browser user agent sent with search requests (reduces bot blocking).
const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
15
/// Arguments accepted by the `WebSearch` tool, deserialized from the tool-call JSON.
#[derive(Debug, Deserialize)]
struct WebSearchArgs {
    /// The search query (validated at execution time to be at least 2 characters).
    query: String,
    /// Only include results whose host is one of these domains (or a subdomain).
    #[serde(default)]
    allowed_domains: Option<Vec<String>>,
    /// Never include results whose host is one of these domains (or a subdomain).
    #[serde(default)]
    blocked_domains: Option<Vec<String>>,
    /// Maximum results to return; clamped to `ABSOLUTE_MAX_RESULTS` at execution time.
    #[serde(default)]
    max_results: Option<usize>,
}
26
/// A cached search payload together with its expiry deadline.
struct CachedSearch {
    /// The final JSON value that was returned to the caller.
    results: serde_json::Value,
    /// Monotonic deadline after which this entry is considered stale.
    expires_at: Instant,
}
31
/// Process-wide search cache, keyed by the normalized query+filters cache key.
/// Guarded by a `parking_lot::RwLock`; entries expire after `CACHE_TTL`.
static SEARCH_CACHE: LazyLock<RwLock<HashMap<String, CachedSearch>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
34
35pub struct WebSearchTool;
36
37impl WebSearchTool {
38    pub fn new() -> Self {
39        Self
40    }
41
42    fn cache_key(
43        query: &str,
44        allowed: &Option<Vec<String>>,
45        blocked: &Option<Vec<String>>,
46    ) -> String {
47        let mut key = query.to_string();
48        if let Some(domains) = allowed {
49            key.push('|');
50            key.push_str(&domains.join(","));
51        }
52        key.push('|');
53        if let Some(domains) = blocked {
54            key.push_str(&domains.join(","));
55        }
56        key
57    }
58
59    fn try_cache(key: &str) -> Option<serde_json::Value> {
60        let cache = SEARCH_CACHE.read();
61        let entry = cache.get(key)?;
62        if entry.expires_at > Instant::now() {
63            Some(entry.results.clone())
64        } else {
65            None
66        }
67    }
68
69    fn put_cache(key: String, results: serde_json::Value) {
70        let mut cache = SEARCH_CACHE.write();
71        cache.insert(
72            key,
73            CachedSearch {
74                results,
75                expires_at: Instant::now() + CACHE_TTL,
76            },
77        );
78    }
79
80    fn decode_duckduckgo_url(raw: &str) -> Option<String> {
81        if let Ok(url) = url::Url::parse(raw) {
82            if let Some(value) = url
83                .query_pairs()
84                .find(|(key, _)| key == "uddg")
85                .map(|(_, value)| value.to_string())
86            {
87                return Some(value);
88            }
89        }
90
91        Some(raw.to_string())
92    }
93
94    fn host_of(url: &str) -> Option<String> {
95        url::Url::parse(url)
96            .ok()
97            .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
98    }
99
100    fn domain_matches(host: &str, domain: &str) -> bool {
101        host == domain || host.ends_with(&format!(".{}", domain))
102    }
103}
104
105impl Default for WebSearchTool {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111#[async_trait]
112impl Tool for WebSearchTool {
113    fn name(&self) -> &str {
114        "WebSearch"
115    }
116
117    fn description(&self) -> &str {
118        "Search DuckDuckGo and return up to 10 filtered results (title, url, domain, snippet) with optional allow/block domain filters."
119    }
120
121    fn mutability(&self) -> crate::ToolMutability {
122        crate::ToolMutability::ReadOnly
123    }
124
125    fn concurrency_safe(&self) -> bool {
126        true
127    }
128
129    fn parameters_schema(&self) -> serde_json::Value {
130        json!({
131            "type": "object",
132            "properties": {
133                "query": {
134                    "type": "string",
135                    "minLength": 2,
136                    "description": "The search query to use"
137                },
138                "allowed_domains": {
139                    "type": "array",
140                    "items": { "type": "string" },
141                    "description": "Only include results from these domains"
142                },
143                "blocked_domains": {
144                    "type": "array",
145                    "items": { "type": "string" },
146                    "description": "Never include results from these domains"
147                },
148                "max_results": {
149                    "type": "number",
150                    "description": "Maximum results to return (default 10, max 20)"
151                }
152            },
153            "required": ["query"],
154            "additionalProperties": false
155        })
156    }
157
158    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
159        self.execute_with_context(args, ToolExecutionContext::none("WebSearch"))
160            .await
161    }
162
163    async fn execute_with_context(
164        &self,
165        args: serde_json::Value,
166        ctx: ToolExecutionContext<'_>,
167    ) -> Result<ToolResult, ToolError> {
168        let parsed: WebSearchArgs = serde_json::from_value(args)
169            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
170
171        let query = parsed.query.trim();
172        if query.len() < 2 {
173            return Err(ToolError::InvalidArguments(
174                "query must be at least 2 characters".to_string(),
175            ));
176        }
177
178        let allowed_domains = parsed.allowed_domains.filter(|v| !v.is_empty());
179        let blocked_domains = parsed.blocked_domains.filter(|v| !v.is_empty());
180
181        // Mutual-exclusion validation
182        if allowed_domains.is_some() && blocked_domains.is_some() {
183            return Err(ToolError::InvalidArguments(
184                "Cannot specify both allowed_domains and blocked_domains in the same request"
185                    .to_string(),
186            ));
187        }
188
189        let max_results = parsed
190            .max_results
191            .unwrap_or(DEFAULT_MAX_RESULTS)
192            .min(ABSOLUTE_MAX_RESULTS);
193
194        // Check cache
195        let cache_key = Self::cache_key(query, &allowed_domains, &blocked_domains);
196        if let Some(cached) = Self::try_cache(&cache_key) {
197            ctx.emit_tool_token("Using cached search results\n").await;
198            return Ok(ToolResult {
199                success: true,
200                result: cached.to_string(),
201                display_preference: Some("Collapsible".to_string()),
202            });
203        }
204
205        ctx.emit_tool_token(format!("Searching: {}\n", query)).await;
206
207        let client = reqwest::Client::builder()
208            .timeout(Duration::from_secs(30))
209            .build()
210            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
211
212        let response = client
213            .get("https://duckduckgo.com/html/")
214            .header("User-Agent", USER_AGENT)
215            .query(&[("q", query)])
216            .send()
217            .await
218            .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
219
220        let html = response.text().await.map_err(|e| {
221            ToolError::Execution(format!("Failed to decode web search body: {}", e))
222        })?;
223
224        // Detect anti-bot page
225        if html.contains("Unfortunately, bots use DuckDuckGo too") || html.contains("anomaly-modal")
226        {
227            return Err(ToolError::Execution(
228                "Search blocked by anti-bot protection. Please retry.".to_string(),
229            ));
230        }
231
232        let allowed: Option<HashSet<String>> = allowed_domains.map(|domains| {
233            domains
234                .into_iter()
235                .map(|value| value.to_ascii_lowercase())
236                .collect()
237        });
238        let blocked: HashSet<String> = blocked_domains
239            .unwrap_or_default()
240            .into_iter()
241            .map(|value| value.to_ascii_lowercase())
242            .collect();
243
244        let link_re = Regex::new(r#"<a[^>]*class="result__a"[^>]*href="([^"]+)"[^>]*>(.*?)</a>"#)
245            .map_err(|e| {
246            ToolError::Execution(format!("Failed to compile link regex: {}", e))
247        })?;
248        let tag_re = Regex::new(r"(?is)<[^>]+>")
249            .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
250        let snippet_re =
251            Regex::new(r#"<a[^>]*class="result__snippet"[^>]*href="[^"]*"[^>]*>(.*?)</a>"#)
252                .map_err(|e| {
253                    ToolError::Execution(format!("Failed to compile snippet regex: {}", e))
254                })?;
255
256        // Build a map of snippet content by href (to match with result links)
257        let href_re = Regex::new(r#"href="([^"]+)""#)
258            .map_err(|e| ToolError::Execution(format!("Failed to compile href regex: {}", e)))?;
259        let mut snippets: HashMap<String, String> = HashMap::new();
260        for cap in snippet_re.captures_iter(&html) {
261            if let Some(href_cap) = cap.get(0) {
262                let href_text = href_cap.as_str();
263                // Extract the href URL from the snippet anchor
264                if let Some(url_match) = href_re.find(href_text) {
265                    let raw_href = &href_text[url_match.start() + 6..url_match.end() - 1];
266                    if let Some(decoded) = Self::decode_duckduckgo_url(raw_href) {
267                        let snippet_text = cap
268                            .get(1)
269                            .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
270                            .unwrap_or_default();
271                        if !snippet_text.is_empty() {
272                            snippets.insert(decoded, snippet_text);
273                        }
274                    }
275                }
276            }
277        }
278
279        let mut results = Vec::new();
280        for capture in link_re.captures_iter(&html) {
281            let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
282                continue;
283            };
284            let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
285                continue;
286            };
287            let Some(host) = Self::host_of(&url) else {
288                continue;
289            };
290
291            if blocked
292                .iter()
293                .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
294            {
295                continue;
296            }
297            if let Some(allowed_set) = &allowed {
298                if !allowed_set
299                    .iter()
300                    .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
301                {
302                    continue;
303                }
304            }
305
306            let title = capture
307                .get(2)
308                .map(|m| tag_re.replace_all(m.as_str(), "").trim().to_string())
309                .unwrap_or_else(|| url.clone());
310
311            let snippet = snippets.get(&url).cloned().unwrap_or_default();
312
313            let mut result = json!({
314                "title": title,
315                "url": url,
316                "domain": host,
317            });
318            if !snippet.is_empty() {
319                result["snippet"] = json!(snippet);
320            }
321            results.push(result);
322
323            if results.len() >= max_results {
324                break;
325            }
326        }
327
328        ctx.emit_tool_token(format!(
329            "Found {} results for \"{}\"\n",
330            results.len(),
331            query
332        ))
333        .await;
334
335        let result_value = if results.is_empty() {
336            json!({
337                "query": parsed.query,
338                "results": [],
339                "note": "No results found for this query.",
340            })
341        } else {
342            json!({
343                "query": parsed.query,
344                "results": results,
345            })
346        };
347
348        // Store in cache
349        Self::put_cache(cache_key, result_value.clone());
350
351        let mut result_string = result_value.to_string();
352        result_string.push_str("\n\nREMINDER: You MUST include a Sources section at the end of your response, listing all relevant URLs as markdown hyperlinks: [Title](URL)");
353
354        Ok(ToolResult {
355            success: true,
356            result: result_string,
357            display_preference: Some("Collapsible".to_string()),
358        })
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn domain_matches_supports_subdomains() {
        // Exact hosts and true subdomains match; lookalike hosts must not.
        let cases = [
            ("example.com", "example.com", true),
            ("docs.example.com", "example.com", true),
            ("notexample.com", "example.com", false),
            ("evil-example.com", "example.com", false),
        ];
        for (host, domain, expected) in cases {
            assert_eq!(WebSearchTool::domain_matches(host, domain), expected);
        }
    }

    #[test]
    fn host_of_normalizes_case() {
        assert_eq!(
            WebSearchTool::host_of("https://Docs.Example.Com/path").as_deref(),
            Some("docs.example.com")
        );
    }

    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let redirect =
            "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        assert_eq!(
            WebSearchTool::decode_duckduckgo_url(redirect).as_deref(),
            Some("https://example.com/page")
        );
    }

    #[test]
    fn cache_key_is_stable() {
        let allowed = Some(vec!["doc.rust-lang.org".to_string()]);
        // Identical inputs must yield identical keys.
        let k1 = WebSearchTool::cache_key("rust", &allowed, &None);
        let k2 = WebSearchTool::cache_key("rust", &allowed, &None);
        assert_eq!(k1, k2);

        // Moving a domain from the allow list to the block list changes the key.
        let blocked = Some(vec!["bad.com".to_string()]);
        assert_ne!(k1, WebSearchTool::cache_key("rust", &None, &blocked));
    }
}