1use agentzero_core::{Tool, ToolContext, ToolResult};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use serde::Deserialize;
5use std::time::Duration;
6use url::form_urlencoded;
7
/// Upper bound applied to the number of results requested from any provider.
const MAX_RESULTS_CAP: usize = 10;
/// Request timeout, in seconds, used by `WebSearchConfig::default()`.
const DEFAULT_TIMEOUT_SECS: u64 = 15;
10
11#[derive(Debug, Deserialize)]
12struct WebSearchInput {
13 query: String,
14 #[serde(default = "default_max_results")]
15 max_results: usize,
16 #[serde(default)]
17 provider: Option<String>,
18}
19
/// Serde fallback for `WebSearchInput::max_results` when the field is absent.
fn default_max_results() -> usize {
    const FALLBACK: usize = 5;
    FALLBACK
}
23
/// Settings that control which search provider is used and how requests are made.
#[derive(Debug, Clone)]
pub struct WebSearchConfig {
    /// Provider used when a request does not name one (`"duckduckgo"`, `"brave"` or `"jina"`).
    pub provider: String,
    /// Brave API key; when `None`, the `BRAVE_API_KEY` environment variable is consulted.
    pub brave_api_key: Option<String>,
    /// Jina API key; when `None`, the `JINA_API_KEY` environment variable is consulted.
    pub jina_api_key: Option<String>,
    /// Timeout, in seconds, handed to the HTTP client builder.
    pub timeout_secs: u64,
    /// `User-Agent` value handed to the HTTP client builder.
    pub user_agent: String,
}
32
33impl Default for WebSearchConfig {
34 fn default() -> Self {
35 Self {
36 provider: "duckduckgo".to_string(),
37 brave_api_key: None,
38 jina_api_key: None,
39 timeout_secs: DEFAULT_TIMEOUT_SECS,
40 user_agent: "AgentZero/1.0".to_string(),
41 }
42 }
43}
44
45pub struct WebSearchTool {
46 client: reqwest::Client,
47 config: WebSearchConfig,
48}
49
50impl Default for WebSearchTool {
51 fn default() -> Self {
52 Self::new(WebSearchConfig::default())
53 }
54}
55
56impl WebSearchTool {
57 pub fn new(config: WebSearchConfig) -> Self {
58 let client = reqwest::Client::builder()
59 .timeout(Duration::from_secs(config.timeout_secs))
60 .user_agent(&config.user_agent)
61 .build()
62 .unwrap_or_default();
63 Self { client, config }
64 }
65
66 async fn search_duckduckgo(&self, query: &str, max_results: usize) -> anyhow::Result<String> {
67 let url = format!(
68 "https://html.duckduckgo.com/html/?q={}",
69 form_urlencoded::byte_serialize(query.as_bytes()).collect::<String>()
70 );
71 let response = self
72 .client
73 .get(&url)
74 .send()
75 .await
76 .context("DuckDuckGo request failed")?;
77 let body = response
78 .text()
79 .await
80 .context("failed reading DuckDuckGo response")?;
81
82 let mut results = Vec::new();
83 for (i, chunk) in body.split("class=\"result__a\"").skip(1).enumerate() {
84 if i >= max_results {
85 break;
86 }
87 let title = extract_between(chunk, ">", "</a>").unwrap_or_default();
88 let href = extract_between(chunk, "href=\"", "\"").unwrap_or_default();
89 let snippet = if let Some(snip_chunk) = chunk.split("class=\"result__snippet\"").nth(1)
90 {
91 extract_between(snip_chunk, ">", "</")
92 .unwrap_or_default()
93 .replace("&", "&")
94 .replace("<", "<")
95 .replace(">", ">")
96 .replace(""", "\"")
97 .replace("<b>", "")
98 .replace("</b>", "")
99 } else {
100 String::new()
101 };
102 results.push(format!(
103 "{}. {}\n {}\n {}",
104 i + 1,
105 clean_html(title),
106 href,
107 clean_html(&snippet)
108 ));
109 }
110
111 if results.is_empty() {
112 Ok("no results found".to_string())
113 } else {
114 Ok(results.join("\n\n"))
115 }
116 }
117
118 async fn search_brave(
119 &self,
120 query: &str,
121 max_results: usize,
122 api_key: &str,
123 ) -> anyhow::Result<String> {
124 let url = format!(
125 "https://api.search.brave.com/res/v1/web/search?q={}&count={}",
126 form_urlencoded::byte_serialize(query.as_bytes()).collect::<String>(),
127 max_results.min(MAX_RESULTS_CAP)
128 );
129 let response = self
130 .client
131 .get(&url)
132 .header("X-Subscription-Token", api_key)
133 .header("Accept", "application/json")
134 .send()
135 .await
136 .context("Brave search request failed")?;
137
138 if !response.status().is_success() {
139 let status = response.status();
140 let body = response.text().await.unwrap_or_default();
141 anyhow::bail!("Brave API returned HTTP {status}: {body}");
142 }
143
144 let body: serde_json::Value = response
145 .json()
146 .await
147 .context("failed parsing Brave response")?;
148 let mut results = Vec::new();
149 if let Some(web) = body
150 .get("web")
151 .and_then(|w| w.get("results"))
152 .and_then(|r| r.as_array())
153 {
154 for (i, item) in web.iter().enumerate().take(max_results) {
155 let title = item.get("title").and_then(|v| v.as_str()).unwrap_or("");
156 let url = item.get("url").and_then(|v| v.as_str()).unwrap_or("");
157 let desc = item
158 .get("description")
159 .and_then(|v| v.as_str())
160 .unwrap_or("");
161 results.push(format!("{}. {}\n {}\n {}", i + 1, title, url, desc));
162 }
163 }
164
165 if results.is_empty() {
166 Ok("no results found".to_string())
167 } else {
168 Ok(results.join("\n\n"))
169 }
170 }
171
172 async fn search_jina(
173 &self,
174 query: &str,
175 max_results: usize,
176 api_key: &str,
177 ) -> anyhow::Result<String> {
178 let url = format!(
179 "https://s.jina.ai/{}",
180 form_urlencoded::byte_serialize(query.as_bytes()).collect::<String>()
181 );
182 let response = self
183 .client
184 .get(&url)
185 .header("Authorization", format!("Bearer {api_key}"))
186 .header("Accept", "application/json")
187 .send()
188 .await
189 .context("Jina search request failed")?;
190
191 if !response.status().is_success() {
192 let status = response.status();
193 let body = response.text().await.unwrap_or_default();
194 anyhow::bail!("Jina API returned HTTP {status}: {body}");
195 }
196
197 let body: serde_json::Value = response
198 .json()
199 .await
200 .context("failed parsing Jina response")?;
201 let mut results = Vec::new();
202 if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
203 for (i, item) in data.iter().enumerate().take(max_results) {
204 let title = item.get("title").and_then(|v| v.as_str()).unwrap_or("");
205 let url = item.get("url").and_then(|v| v.as_str()).unwrap_or("");
206 let desc = item
207 .get("description")
208 .and_then(|v| v.as_str())
209 .unwrap_or("");
210 results.push(format!("{}. {}\n {}\n {}", i + 1, title, url, desc));
211 }
212 }
213
214 if results.is_empty() {
215 Ok("no results found".to_string())
216 } else {
217 Ok(results.join("\n\n"))
218 }
219 }
220}
221
/// Returns the text between the first `start` marker and the first `end` marker
/// that follows it, or `None` when either marker is missing.
fn extract_between<'a>(text: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let (_, after_start) = text.split_once(start)?;
    let (inner, _) = after_start.split_once(end)?;
    Some(inner)
}
227
/// Strips HTML tags from `s` and decodes a small set of character entities.
///
/// Everything from a `<` up to and including the next `>` is dropped (a stray `>`
/// outside a tag is dropped as well). Afterwards `&lt;`, `&gt;`, `&quot;`, `&#39;`,
/// `&#x27;` and `&amp;` are decoded.
fn clean_html(s: &str) -> String {
    let mut text = String::with_capacity(s.len());
    let mut in_tag = false;
    for ch in s.chars() {
        match ch {
            '<' => in_tag = true,
            '>' => in_tag = false,
            _ if !in_tag => text.push(ch),
            _ => {}
        }
    }
    // `&amp;` is decoded last so that an escaped entity such as `&amp;lt;` becomes
    // the literal text `&lt;` instead of being decoded a second time into `<`.
    text.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&#x27;", "'")
        .replace("&amp;", "&")
}
245
246#[async_trait]
247impl Tool for WebSearchTool {
248 fn name(&self) -> &'static str {
249 "web_search"
250 }
251
252 fn description(&self) -> &'static str {
253 "Search the web using DuckDuckGo, Brave, or Jina and return a summary of results."
254 }
255
256 fn input_schema(&self) -> Option<serde_json::Value> {
257 Some(serde_json::json!({
258 "type": "object",
259 "properties": {
260 "query": {
261 "type": "string",
262 "description": "The search query"
263 }
264 },
265 "required": ["query"]
266 }))
267 }
268
269 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
270 let req: WebSearchInput =
271 serde_json::from_str(input).context("web_search expects JSON: {\"query\": \"...\"}")?;
272
273 if req.query.trim().is_empty() {
274 return Err(anyhow!("query must not be empty"));
275 }
276
277 let max = req.max_results.clamp(1, MAX_RESULTS_CAP);
278 let provider = req.provider.as_deref().unwrap_or(&self.config.provider);
279
280 let brave_env_key = std::env::var("BRAVE_API_KEY").ok();
281 let jina_env_key = std::env::var("JINA_API_KEY").ok();
282
283 let output = match provider {
284 "brave" => {
285 let key = self
286 .config
287 .brave_api_key
288 .as_deref()
289 .or(brave_env_key.as_deref())
290 .ok_or_else(|| {
291 anyhow!("brave_api_key is required for Brave search provider")
292 })?;
293 self.search_brave(&req.query, max, key).await?
294 }
295 "jina" => {
296 let key = self
297 .config
298 .jina_api_key
299 .as_deref()
300 .or(jina_env_key.as_deref())
301 .ok_or_else(|| anyhow!("jina_api_key is required for Jina search provider"))?;
302 self.search_jina(&req.query, max, key).await?
303 }
304 _ => self.search_duckduckgo(&req.query, max).await?,
305 };
306
307 Ok(ToolResult { output })
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `ToolContext` (constructed from ".") that the async tests pass to `execute`.
    fn ctx() -> ToolContext {
        ToolContext::new(".".to_string())
    }

    /// A blank query is rejected before any provider is contacted.
    #[tokio::test]
    async fn web_search_rejects_empty_query() {
        let tool = WebSearchTool::default();
        let outcome = tool.execute(r#"{"query": ""}"#, &ctx()).await;
        let message = outcome.expect_err("empty query should fail").to_string();
        assert!(message.contains("query must not be empty"));
    }

    /// Input that is not JSON surfaces the usage hint.
    #[tokio::test]
    async fn web_search_rejects_invalid_json() {
        let tool = WebSearchTool::default();
        let outcome = tool.execute("not json", &ctx()).await;
        let message = outcome.expect_err("invalid JSON should fail").to_string();
        assert!(message.contains("web_search expects JSON"));
    }

    /// Brave without a configured key must fail with a message naming the key.
    ///
    /// `execute` also falls back to the `BRAVE_API_KEY` environment variable, so this
    /// test relies on that variable being unset in the test environment.
    #[tokio::test]
    async fn web_search_brave_requires_api_key() {
        let config = WebSearchConfig {
            provider: "brave".to_string(),
            brave_api_key: None,
            ..Default::default()
        };
        let tool = WebSearchTool::new(config);
        let outcome = tool.execute(r#"{"query": "test"}"#, &ctx()).await;
        let message = outcome.expect_err("missing API key should fail").to_string();
        assert!(message.contains("brave_api_key"));
    }

    /// Tags are removed and plain text is left untouched.
    #[test]
    fn clean_html_strips_tags() {
        let cases = [("<b>hello</b> world", "hello world"), ("no tags", "no tags")];
        for (input, expected) in cases {
            assert_eq!(clean_html(input), expected);
        }
    }

    /// Both markers must be present for a match.
    #[test]
    fn extract_between_works() {
        let found = extract_between("foo=bar;baz", "=", ";");
        assert_eq!(found, Some("bar"));

        let missing = extract_between("nothing", "=", ";");
        assert_eq!(missing, None);
    }
}