Skip to main content

bamboo_tools/tools/
web_search.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use regex::Regex;
4use serde::Deserialize;
5use serde_json::json;
6use std::collections::HashSet;
7use std::time::Duration;
8
9#[derive(Debug, Deserialize)]
10struct WebSearchArgs {
11    query: String,
12    #[serde(default)]
13    allowed_domains: Option<Vec<String>>,
14    #[serde(default)]
15    blocked_domains: Option<Vec<String>>,
16}
17
/// Tool that searches the web through DuckDuckGo's HTML endpoint.
///
/// The type carries no state; everything it needs comes from the per-call
/// arguments, so it is trivially copyable.
#[derive(Debug, Clone, Copy)]
pub struct WebSearchTool;
19
20impl WebSearchTool {
21    pub fn new() -> Self {
22        Self
23    }
24
25    fn decode_duckduckgo_url(raw: &str) -> Option<String> {
26        if let Ok(url) = url::Url::parse(raw) {
27            if let Some(value) = url
28                .query_pairs()
29                .find(|(key, _)| key == "uddg")
30                .map(|(_, value)| value.to_string())
31            {
32                return Some(value);
33            }
34        }
35
36        Some(raw.to_string())
37    }
38
39    fn host_of(url: &str) -> Option<String> {
40        url::Url::parse(url)
41            .ok()
42            .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
43    }
44
45    fn domain_matches(host: &str, domain: &str) -> bool {
46        host == domain || host.ends_with(&format!(".{}", domain))
47    }
48}
49
50impl Default for WebSearchTool {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56#[async_trait]
57impl Tool for WebSearchTool {
58    fn name(&self) -> &str {
59        "WebSearch"
60    }
61
62    fn description(&self) -> &str {
63        "Search DuckDuckGo and return up to 10 filtered results (title, url, domain) with optional allow/block domain filters."
64    }
65
66    fn mutability(&self) -> crate::ToolMutability {
67        crate::ToolMutability::ReadOnly
68    }
69
70    fn concurrency_safe(&self) -> bool {
71        true
72    }
73
74    fn parameters_schema(&self) -> serde_json::Value {
75        json!({
76            "type": "object",
77            "properties": {
78                "query": {
79                    "type": "string",
80                    "minLength": 2,
81                    "description": "The search query to use"
82                },
83                "allowed_domains": {
84                    "type": "array",
85                    "items": { "type": "string" },
86                    "description": "Only include results from these domains"
87                },
88                "blocked_domains": {
89                    "type": "array",
90                    "items": { "type": "string" },
91                    "description": "Never include results from these domains"
92                }
93            },
94            "required": ["query"],
95            "additionalProperties": false
96        })
97    }
98
99    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
100        let parsed: WebSearchArgs = serde_json::from_value(args)
101            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
102
103        let client = reqwest::Client::builder()
104            .timeout(Duration::from_secs(30))
105            .build()
106            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
107
108        let response = client
109            .get("https://duckduckgo.com/html/")
110            .query(&[("q", parsed.query.trim())])
111            .send()
112            .await
113            .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
114
115        let html = response.text().await.map_err(|e| {
116            ToolError::Execution(format!("Failed to decode web search body: {}", e))
117        })?;
118
119        let allowed: Option<HashSet<String>> = parsed.allowed_domains.map(|domains| {
120            domains
121                .into_iter()
122                .map(|value| value.to_ascii_lowercase())
123                .collect()
124        });
125        let blocked: HashSet<String> = parsed
126            .blocked_domains
127            .unwrap_or_default()
128            .into_iter()
129            .map(|value| value.to_ascii_lowercase())
130            .collect();
131
132        let link_re =
133            Regex::new(r#"<a[^>]*class=\"result__a\"[^>]*href=\"([^\"]+)\"[^>]*>(.*?)</a>"#)
134                .map_err(|e| {
135                    ToolError::Execution(format!("Failed to compile parser regex: {}", e))
136                })?;
137        let tag_re = Regex::new(r"(?is)<[^>]+>")
138            .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
139
140        let mut results = Vec::new();
141        for capture in link_re.captures_iter(&html) {
142            let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
143                continue;
144            };
145            let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
146                continue;
147            };
148            let Some(host) = Self::host_of(&url) else {
149                continue;
150            };
151
152            if blocked
153                .iter()
154                .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
155            {
156                continue;
157            }
158            if let Some(allowed_set) = &allowed {
159                if !allowed_set
160                    .iter()
161                    .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
162                {
163                    continue;
164                }
165            }
166
167            let title = capture
168                .get(2)
169                .map(|m| tag_re.replace_all(m.as_str(), "").to_string())
170                .unwrap_or_else(|| url.clone());
171
172            results.push(json!({
173                "title": title,
174                "url": url,
175                "domain": host,
176            }));
177
178            if results.len() >= 10 {
179                break;
180            }
181        }
182
183        Ok(ToolResult {
184            success: true,
185            result: json!({
186                "query": parsed.query,
187                "results": results,
188            })
189            .to_string(),
190            display_preference: Some("Collapsible".to_string()),
191        })
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Whole-label matching: exact host and true subdomains match, while hosts
    /// that merely end with the same characters do not.
    #[test]
    fn domain_matches_supports_subdomains() {
        // (host, domain, expected)
        let cases = [
            ("example.com", "example.com", true),
            ("docs.example.com", "example.com", true),
            ("notexample.com", "example.com", false),
            ("evil-example.com", "example.com", false),
        ];
        for (host, domain, expected) in cases {
            assert_eq!(
                WebSearchTool::domain_matches(host, domain),
                expected,
                "domain_matches({:?}, {:?})",
                host,
                domain
            );
        }
    }

    /// Hosts are reported in lowercase regardless of the input URL's casing.
    #[test]
    fn host_of_normalizes_case() {
        assert_eq!(
            WebSearchTool::host_of("https://Docs.Example.Com/path").as_deref(),
            Some("docs.example.com")
        );
    }

    /// A DuckDuckGo redirect link yields its percent-decoded `uddg` target.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let redirect =
            "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        assert_eq!(
            WebSearchTool::decode_duckduckgo_url(redirect).as_deref(),
            Some("https://example.com/page")
        );
    }
}