// bamboo_tools/tools/web_search.rs
use async_trait::async_trait;
use bamboo_agent_core::{Tool, ToolError, ToolResult};
use regex::Regex;
use serde::Deserialize;
use serde_json::json;
use std::collections::HashSet;
use std::time::Duration;
/// Arguments accepted by the `WebSearch` tool, deserialized from the
/// tool-call JSON payload (see `parameters_schema`).
#[derive(Debug, Deserialize)]
struct WebSearchArgs {
    /// The search query forwarded to DuckDuckGo.
    query: String,
    /// When present, only results whose host is one of these domains
    /// (or a subdomain of one) are returned.
    #[serde(default)]
    allowed_domains: Option<Vec<String>>,
    /// Results whose host matches any of these domains are dropped.
    #[serde(default)]
    blocked_domains: Option<Vec<String>>,
}
17
18pub struct WebSearchTool;
19
20impl WebSearchTool {
21 pub fn new() -> Self {
22 Self
23 }
24
25 fn decode_duckduckgo_url(raw: &str) -> Option<String> {
26 if let Ok(url) = url::Url::parse(raw) {
27 if let Some(value) = url
28 .query_pairs()
29 .find(|(key, _)| key == "uddg")
30 .map(|(_, value)| value.to_string())
31 {
32 return Some(value);
33 }
34 }
35
36 Some(raw.to_string())
37 }
38
39 fn host_of(url: &str) -> Option<String> {
40 url::Url::parse(url)
41 .ok()
42 .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
43 }
44
45 fn domain_matches(host: &str, domain: &str) -> bool {
46 host == domain || host.ends_with(&format!(".{}", domain))
47 }
48}
49
50impl Default for WebSearchTool {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56#[async_trait]
57impl Tool for WebSearchTool {
58 fn name(&self) -> &str {
59 "WebSearch"
60 }
61
62 fn description(&self) -> &str {
63 "Search DuckDuckGo and return up to 10 filtered results (title, url, domain) with optional allow/block domain filters."
64 }
65
66 fn mutability(&self) -> crate::ToolMutability {
67 crate::ToolMutability::ReadOnly
68 }
69
70 fn concurrency_safe(&self) -> bool {
71 true
72 }
73
74 fn parameters_schema(&self) -> serde_json::Value {
75 json!({
76 "type": "object",
77 "properties": {
78 "query": {
79 "type": "string",
80 "minLength": 2,
81 "description": "The search query to use"
82 },
83 "allowed_domains": {
84 "type": "array",
85 "items": { "type": "string" },
86 "description": "Only include results from these domains"
87 },
88 "blocked_domains": {
89 "type": "array",
90 "items": { "type": "string" },
91 "description": "Never include results from these domains"
92 }
93 },
94 "required": ["query"],
95 "additionalProperties": false
96 })
97 }
98
99 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
100 let parsed: WebSearchArgs = serde_json::from_value(args)
101 .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebSearch args: {}", e)))?;
102
103 let client = reqwest::Client::builder()
104 .timeout(Duration::from_secs(30))
105 .build()
106 .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
107
108 let response = client
109 .get("https://duckduckgo.com/html/")
110 .query(&[("q", parsed.query.trim())])
111 .send()
112 .await
113 .map_err(|e| ToolError::Execution(format!("Web search request failed: {}", e)))?;
114
115 let html = response.text().await.map_err(|e| {
116 ToolError::Execution(format!("Failed to decode web search body: {}", e))
117 })?;
118
119 let allowed: Option<HashSet<String>> = parsed.allowed_domains.map(|domains| {
120 domains
121 .into_iter()
122 .map(|value| value.to_ascii_lowercase())
123 .collect()
124 });
125 let blocked: HashSet<String> = parsed
126 .blocked_domains
127 .unwrap_or_default()
128 .into_iter()
129 .map(|value| value.to_ascii_lowercase())
130 .collect();
131
132 let link_re =
133 Regex::new(r#"<a[^>]*class=\"result__a\"[^>]*href=\"([^\"]+)\"[^>]*>(.*?)</a>"#)
134 .map_err(|e| {
135 ToolError::Execution(format!("Failed to compile parser regex: {}", e))
136 })?;
137 let tag_re = Regex::new(r"(?is)<[^>]+>")
138 .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
139
140 let mut results = Vec::new();
141 for capture in link_re.captures_iter(&html) {
142 let Some(raw_url) = capture.get(1).map(|m| m.as_str()) else {
143 continue;
144 };
145 let Some(url) = Self::decode_duckduckgo_url(raw_url) else {
146 continue;
147 };
148 let Some(host) = Self::host_of(&url) else {
149 continue;
150 };
151
152 if blocked
153 .iter()
154 .any(|blocked_domain| Self::domain_matches(&host, blocked_domain))
155 {
156 continue;
157 }
158 if let Some(allowed_set) = &allowed {
159 if !allowed_set
160 .iter()
161 .any(|allowed_domain| Self::domain_matches(&host, allowed_domain))
162 {
163 continue;
164 }
165 }
166
167 let title = capture
168 .get(2)
169 .map(|m| tag_re.replace_all(m.as_str(), "").to_string())
170 .unwrap_or_else(|| url.clone());
171
172 results.push(json!({
173 "title": title,
174 "url": url,
175 "domain": host,
176 }));
177
178 if results.len() >= 10 {
179 break;
180 }
181 }
182
183 Ok(ToolResult {
184 success: true,
185 result: json!({
186 "query": parsed.query,
187 "results": results,
188 })
189 .to_string(),
190 display_preference: Some("Collapsible".to_string()),
191 })
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    // Exact hosts and true subdomains match; lookalike hosts must not.
    #[test]
    fn domain_matches_supports_subdomains() {
        let cases = [
            ("example.com", "example.com", true),
            ("docs.example.com", "example.com", true),
            ("notexample.com", "example.com", false),
            ("evil-example.com", "example.com", false),
        ];
        for (host, domain, expected) in cases {
            assert_eq!(
                WebSearchTool::domain_matches(host, domain),
                expected,
                "host={host} domain={domain}"
            );
        }
    }

    // Hosts are lowercased so domain filters compare case-insensitively.
    #[test]
    fn host_of_normalizes_case() {
        assert_eq!(
            WebSearchTool::host_of("https://Docs.Example.Com/path").as_deref(),
            Some("docs.example.com")
        );
    }

    // The redirect target in `uddg` is extracted and percent-decoded.
    #[test]
    fn decode_duckduckgo_url_extracts_uddg_param() {
        let raw = "https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&rut=whatever";
        assert_eq!(
            WebSearchTool::decode_duckduckgo_url(raw).as_deref(),
            Some("https://example.com/page")
        );
    }
}