Skip to main content

agentzero_tools/
web_search.rs

1use agentzero_core::{Tool, ToolContext, ToolResult};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use serde::Deserialize;
5use std::time::Duration;
6use url::form_urlencoded;
7
/// Upper bound on the number of results a single search request may return.
const MAX_RESULTS_CAP: usize = 10;
/// HTTP request timeout, in seconds, used by `WebSearchConfig::default`.
const DEFAULT_TIMEOUT_SECS: u64 = 15;
10
/// JSON payload accepted by the `web_search` tool.
#[derive(Debug, Deserialize)]
struct WebSearchInput {
    /// Search query text; `execute` rejects it when it is empty after trimming.
    query: String,
    /// Requested number of results; falls back to `default_max_results()` when omitted
    /// and is clamped to `1..=MAX_RESULTS_CAP` by `execute`.
    #[serde(default = "default_max_results")]
    max_results: usize,
    /// Optional per-request provider name; when absent the configured provider is used.
    #[serde(default)]
    provider: Option<String>,
}
19
/// Serde fallback for `WebSearchInput::max_results` when the field is omitted.
fn default_max_results() -> usize {
    const FALLBACK_RESULT_COUNT: usize = 5;
    FALLBACK_RESULT_COUNT
}
23
/// Static configuration for [`WebSearchTool`].
#[derive(Debug, Clone)]
pub struct WebSearchConfig {
    /// Provider used when a request does not name one. `"brave"` and `"jina"` select
    /// those backends; other values are routed to DuckDuckGo.
    pub provider: String,
    /// API key for Brave; when `None`, the `BRAVE_API_KEY` environment variable is consulted.
    pub brave_api_key: Option<String>,
    /// API key for Jina; when `None`, the `JINA_API_KEY` environment variable is consulted.
    pub jina_api_key: Option<String>,
    /// Request timeout, in seconds, applied to the HTTP client.
    pub timeout_secs: u64,
    /// Value sent as the HTTP `User-Agent` by the client.
    pub user_agent: String,
}
32
33impl Default for WebSearchConfig {
34    fn default() -> Self {
35        Self {
36            provider: "duckduckgo".to_string(),
37            brave_api_key: None,
38            jina_api_key: None,
39            timeout_secs: DEFAULT_TIMEOUT_SECS,
40            user_agent: "AgentZero/1.0".to_string(),
41        }
42    }
43}
44
/// Web search tool backed by DuckDuckGo, Brave, or Jina.
pub struct WebSearchTool {
    /// HTTP client built in `new` from the configured timeout and user agent.
    client: reqwest::Client,
    /// Provider selection and API keys used when executing a search.
    config: WebSearchConfig,
}
49
50impl Default for WebSearchTool {
51    fn default() -> Self {
52        Self::new(WebSearchConfig::default())
53    }
54}
55
56impl WebSearchTool {
57    pub fn new(config: WebSearchConfig) -> Self {
58        let client = reqwest::Client::builder()
59            .timeout(Duration::from_secs(config.timeout_secs))
60            .user_agent(&config.user_agent)
61            .build()
62            .unwrap_or_default();
63        Self { client, config }
64    }
65
66    async fn search_duckduckgo(&self, query: &str, max_results: usize) -> anyhow::Result<String> {
67        let url = format!(
68            "https://html.duckduckgo.com/html/?q={}",
69            form_urlencoded::byte_serialize(query.as_bytes()).collect::<String>()
70        );
71        let response = self
72            .client
73            .get(&url)
74            .send()
75            .await
76            .context("DuckDuckGo request failed")?;
77        let body = response
78            .text()
79            .await
80            .context("failed reading DuckDuckGo response")?;
81
82        let mut results = Vec::new();
83        for (i, chunk) in body.split("class=\"result__a\"").skip(1).enumerate() {
84            if i >= max_results {
85                break;
86            }
87            let title = extract_between(chunk, ">", "</a>").unwrap_or_default();
88            let href = extract_between(chunk, "href=\"", "\"").unwrap_or_default();
89            let snippet = if let Some(snip_chunk) = chunk.split("class=\"result__snippet\"").nth(1)
90            {
91                extract_between(snip_chunk, ">", "</")
92                    .unwrap_or_default()
93                    .replace("&amp;", "&")
94                    .replace("&lt;", "<")
95                    .replace("&gt;", ">")
96                    .replace("&quot;", "\"")
97                    .replace("<b>", "")
98                    .replace("</b>", "")
99            } else {
100                String::new()
101            };
102            results.push(format!(
103                "{}. {}\n   {}\n   {}",
104                i + 1,
105                clean_html(title),
106                href,
107                clean_html(&snippet)
108            ));
109        }
110
111        if results.is_empty() {
112            Ok("no results found".to_string())
113        } else {
114            Ok(results.join("\n\n"))
115        }
116    }
117
118    async fn search_brave(
119        &self,
120        query: &str,
121        max_results: usize,
122        api_key: &str,
123    ) -> anyhow::Result<String> {
124        let url = format!(
125            "https://api.search.brave.com/res/v1/web/search?q={}&count={}",
126            form_urlencoded::byte_serialize(query.as_bytes()).collect::<String>(),
127            max_results.min(MAX_RESULTS_CAP)
128        );
129        let response = self
130            .client
131            .get(&url)
132            .header("X-Subscription-Token", api_key)
133            .header("Accept", "application/json")
134            .send()
135            .await
136            .context("Brave search request failed")?;
137
138        if !response.status().is_success() {
139            let status = response.status();
140            let body = response.text().await.unwrap_or_default();
141            anyhow::bail!("Brave API returned HTTP {status}: {body}");
142        }
143
144        let body: serde_json::Value = response
145            .json()
146            .await
147            .context("failed parsing Brave response")?;
148        let mut results = Vec::new();
149        if let Some(web) = body
150            .get("web")
151            .and_then(|w| w.get("results"))
152            .and_then(|r| r.as_array())
153        {
154            for (i, item) in web.iter().enumerate().take(max_results) {
155                let title = item.get("title").and_then(|v| v.as_str()).unwrap_or("");
156                let url = item.get("url").and_then(|v| v.as_str()).unwrap_or("");
157                let desc = item
158                    .get("description")
159                    .and_then(|v| v.as_str())
160                    .unwrap_or("");
161                results.push(format!("{}. {}\n   {}\n   {}", i + 1, title, url, desc));
162            }
163        }
164
165        if results.is_empty() {
166            Ok("no results found".to_string())
167        } else {
168            Ok(results.join("\n\n"))
169        }
170    }
171
172    async fn search_jina(
173        &self,
174        query: &str,
175        max_results: usize,
176        api_key: &str,
177    ) -> anyhow::Result<String> {
178        let url = format!(
179            "https://s.jina.ai/{}",
180            form_urlencoded::byte_serialize(query.as_bytes()).collect::<String>()
181        );
182        let response = self
183            .client
184            .get(&url)
185            .header("Authorization", format!("Bearer {api_key}"))
186            .header("Accept", "application/json")
187            .send()
188            .await
189            .context("Jina search request failed")?;
190
191        if !response.status().is_success() {
192            let status = response.status();
193            let body = response.text().await.unwrap_or_default();
194            anyhow::bail!("Jina API returned HTTP {status}: {body}");
195        }
196
197        let body: serde_json::Value = response
198            .json()
199            .await
200            .context("failed parsing Jina response")?;
201        let mut results = Vec::new();
202        if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
203            for (i, item) in data.iter().enumerate().take(max_results) {
204                let title = item.get("title").and_then(|v| v.as_str()).unwrap_or("");
205                let url = item.get("url").and_then(|v| v.as_str()).unwrap_or("");
206                let desc = item
207                    .get("description")
208                    .and_then(|v| v.as_str())
209                    .unwrap_or("");
210                results.push(format!("{}. {}\n   {}\n   {}", i + 1, title, url, desc));
211            }
212        }
213
214        if results.is_empty() {
215            Ok("no results found".to_string())
216        } else {
217            Ok(results.join("\n\n"))
218        }
219    }
220}
221
/// Returns the text between the first `start` marker and the next `end` marker after it.
///
/// Yields `None` when either marker is missing.
fn extract_between<'a>(text: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let (_, after_start) = text.split_once(start)?;
    let (inner, _) = after_start.split_once(end)?;
    Some(inner)
}
227
/// Strips HTML tags from `s` and decodes a small set of character entities.
///
/// Everything between `<` and `>` is dropped. The entities `&amp;`, `&lt;`, `&gt;`,
/// `&quot;`, `&#39;` and `&#x27;` are then decoded in a single left-to-right pass, so
/// each entity is decoded exactly once (`&amp;lt;` becomes `&lt;`, not `<`).
/// Unrecognised `&` sequences are kept verbatim.
fn clean_html(s: &str) -> String {
    const ENTITIES: [(&str, char); 6] = [
        ("&amp;", '&'),
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&quot;", '"'),
        ("&#39;", '\''),
        ("&#x27;", '\''),
    ];

    // Pass 1: drop tag markup, keeping only text outside `<...>`.
    let mut text = String::with_capacity(s.len());
    let mut in_tag = false;
    for ch in s.chars() {
        match ch {
            '<' => in_tag = true,
            '>' => in_tag = false,
            _ if !in_tag => text.push(ch),
            _ => {}
        }
    }

    // Pass 2: decode entities without rescanning already-decoded output.
    let mut out = String::with_capacity(text.len());
    let mut rest = text.as_str();
    while let Some(pos) = rest.find('&') {
        out.push_str(&rest[..pos]);
        rest = &rest[pos..];
        match ENTITIES.iter().find(|(name, _)| rest.starts_with(*name)) {
            Some((name, decoded)) => {
                out.push(*decoded);
                rest = &rest[name.len()..];
            }
            None => {
                // Bare ampersand or unknown entity: keep it as-is.
                out.push('&');
                rest = &rest[1..];
            }
        }
    }
    out.push_str(rest);
    out
}
245
246#[async_trait]
247impl Tool for WebSearchTool {
248    fn name(&self) -> &'static str {
249        "web_search"
250    }
251
252    fn description(&self) -> &'static str {
253        "Search the web using DuckDuckGo, Brave, or Jina and return a summary of results."
254    }
255
256    fn input_schema(&self) -> Option<serde_json::Value> {
257        Some(serde_json::json!({
258            "type": "object",
259            "properties": {
260                "query": {
261                    "type": "string",
262                    "description": "The search query"
263                }
264            },
265            "required": ["query"]
266        }))
267    }
268
269    async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
270        let req: WebSearchInput =
271            serde_json::from_str(input).context("web_search expects JSON: {\"query\": \"...\"}")?;
272
273        if req.query.trim().is_empty() {
274            return Err(anyhow!("query must not be empty"));
275        }
276
277        let max = req.max_results.clamp(1, MAX_RESULTS_CAP);
278        let provider = req.provider.as_deref().unwrap_or(&self.config.provider);
279
280        let brave_env_key = std::env::var("BRAVE_API_KEY").ok();
281        let jina_env_key = std::env::var("JINA_API_KEY").ok();
282
283        let output = match provider {
284            "brave" => {
285                let key = self
286                    .config
287                    .brave_api_key
288                    .as_deref()
289                    .or(brave_env_key.as_deref())
290                    .ok_or_else(|| {
291                        anyhow!("brave_api_key is required for Brave search provider")
292                    })?;
293                self.search_brave(&req.query, max, key).await?
294            }
295            "jina" => {
296                let key = self
297                    .config
298                    .jina_api_key
299                    .as_deref()
300                    .or(jina_env_key.as_deref())
301                    .ok_or_else(|| anyhow!("jina_api_key is required for Jina search provider"))?;
302                self.search_jina(&req.query, max, key).await?
303            }
304            _ => self.search_duckduckgo(&req.query, max).await?,
305        };
306
307        Ok(ToolResult { output })
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Context rooted at the current directory, shared by the async tests.
    fn test_ctx() -> ToolContext {
        ToolContext::new(".".to_string())
    }

    /// A blank query is rejected before any provider is contacted.
    #[tokio::test]
    async fn web_search_rejects_empty_query() {
        let tool = WebSearchTool::default();
        let ctx = test_ctx();
        let outcome = tool.execute(r#"{"query": ""}"#, &ctx).await;
        let err = outcome.expect_err("empty query should fail");
        assert!(err.to_string().contains("query must not be empty"));
    }

    /// Input that is not JSON produces the usage hint.
    #[tokio::test]
    async fn web_search_rejects_invalid_json() {
        let tool = WebSearchTool::default();
        let ctx = test_ctx();
        let outcome = tool.execute("not json", &ctx).await;
        let err = outcome.expect_err("invalid JSON should fail");
        assert!(err.to_string().contains("web_search expects JSON"));
    }

    /// Selecting Brave without a key fails.
    ///
    /// NOTE(review): `execute` also consults the `BRAVE_API_KEY` environment variable,
    /// so this test assumes that variable is unset in the test environment.
    #[tokio::test]
    async fn web_search_brave_requires_api_key() {
        let config = WebSearchConfig {
            provider: "brave".to_string(),
            brave_api_key: None,
            ..Default::default()
        };
        let tool = WebSearchTool::new(config);
        let ctx = test_ctx();
        let outcome = tool.execute(r#"{"query": "test"}"#, &ctx).await;
        let err = outcome.expect_err("missing API key should fail");
        assert!(err.to_string().contains("brave_api_key"));
    }

    /// Tags are removed and plain text passes through unchanged.
    #[test]
    fn clean_html_strips_tags() {
        let cases = [("<b>hello</b> world", "hello world"), ("no tags", "no tags")];
        for (input, expected) in cases {
            assert_eq!(clean_html(input), expected);
        }
    }

    /// Delimited text is found, and a missing delimiter yields `None`.
    #[test]
    fn extract_between_works() {
        let found = extract_between("foo=bar;baz", "=", ";");
        assert_eq!(found, Some("bar"));
        let missing = extract_between("nothing", "=", ";");
        assert_eq!(missing, None);
    }
}