use crate::tool::Tool;
use async_trait::async_trait;
use serde_json::json;
use std::time::Duration;
/// Result limit a freshly constructed `WebSearchTool` starts with.
const DEFAULT_MAX_RESULTS: usize = 5;
/// HTTP request timeout, in seconds, a freshly constructed `WebSearchTool` starts with.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Web search tool backed by either DuckDuckGo's HTML endpoint or the Brave Search API.
///
/// Build one with `WebSearchTool::new` and adjust it with the `with_*` builder methods.
pub struct WebSearchTool {
/// Provider name; `new` stores it lower-cased.
provider: String,
/// API key sent to Brave; `None` until `with_brave_key` is called.
brave_api_key: Option<String>,
/// Maximum number of results a search returns; `with_max_results` keeps it within 1-10.
max_results: usize,
/// Timeout applied to each HTTP request, in seconds.
timeout_secs: u64,
}
impl WebSearchTool {
    /// Creates a search tool for the named provider.
    ///
    /// The provider name is lower-cased before it is stored. No Brave API key is set, and the
    /// result limit and timeout start at `DEFAULT_MAX_RESULTS` and `DEFAULT_TIMEOUT_SECS`.
    pub fn new(provider: impl Into<String>) -> Self {
        Self {
            provider: provider.into().to_lowercase(),
            brave_api_key: None,
            max_results: DEFAULT_MAX_RESULTS,
            timeout_secs: DEFAULT_TIMEOUT_SECS,
        }
    }

    /// Sets the API key sent with Brave Search requests.
    pub fn with_brave_key(mut self, key: impl Into<String>) -> Self {
        self.brave_api_key = Some(key.into());
        self
    }

    /// Sets the maximum number of results, clamped to the range 1-10.
    pub fn with_max_results(mut self, max: usize) -> Self {
        self.max_results = max.clamp(1, 10);
        self
    }

    /// Sets the HTTP request timeout in seconds; values below 1 are raised to 1.
    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = secs.max(1);
        self
    }

    /// Fetches DuckDuckGo's HTML results page for `query` and parses it.
    ///
    /// # Errors
    ///
    /// Fails if the HTTP client cannot be built, the request fails or times out, the response
    /// status is not a success, or the body cannot be read.
    async fn search_duckduckgo(&self, query: &str) -> anyhow::Result<Vec<SearchResult>> {
        let search_url = format!(
            "https://html.duckduckgo.com/html/?q={}",
            urlencoding::encode(query)
        );

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.timeout_secs))
            .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
            .build()?;

        let response = client.get(&search_url).send().await?;
        let status = response.status();
        if !status.is_success() {
            anyhow::bail!("DuckDuckGo search failed: {}", status);
        }

        let html = response.text().await?;
        self.parse_duckduckgo_results(&html)
    }

    /// Extracts up to `max_results` title/URL pairs from a DuckDuckGo HTML results page.
    ///
    /// Links are recognized by their `result__a` class. Hrefs that point at DuckDuckGo's
    /// `/l/?` redirect are unwrapped to the target carried in the `uddg` parameter. Snippets
    /// are not extracted, so `snippet` is always empty.
    ///
    /// # Errors
    ///
    /// Fails only if the result regex does not compile.
    fn parse_duckduckgo_results(&self, html: &str) -> anyhow::Result<Vec<SearchResult>> {
        let result_regex =
            regex::Regex::new(r#"class="result__a"[^>]*href="([^"]+)"[^>]*>([^<]+)"#)?;

        let results = result_regex
            .captures_iter(html)
            .take(self.max_results)
            .map(|cap| {
                let href = cap.get(1).map_or("", |m| m.as_str());
                let raw_title = cap.get(2).map_or("", |m| m.as_str());

                let url = if href.starts_with("//duckduckgo.com/l/?") {
                    href.split("uddg=")
                        .nth(1)
                        // Stop at the next `&` (which also covers an HTML-escaped `&amp;`) so
                        // parameters following `uddg` are not appended to the target URL.
                        .and_then(|rest| rest.split('&').next())
                        .and_then(|encoded| urlencoding::decode(encoded).ok())
                        .map(|decoded| decoded.into_owned())
                        // Fall back to the raw href when the target cannot be recovered.
                        .unwrap_or_else(|| href.to_string())
                } else {
                    href.to_string()
                };

                let mut title = String::new();
                html_escape::decode_html_entities_to_string(raw_title, &mut title);

                SearchResult {
                    title,
                    url,
                    snippet: String::new(),
                }
            })
            .collect();

        Ok(results)
    }

    /// Queries the Brave Search API for `query` and returns up to `max_results` web results.
    ///
    /// Fields missing from a result entry become empty strings; a response without a
    /// `web.results` array yields an empty list.
    ///
    /// # Errors
    ///
    /// Fails if no API key is configured, the HTTP client cannot be built, the request fails
    /// or times out, the response status is not a success (the response body is included in
    /// the error), or the body is not valid JSON.
    async fn search_brave(&self, query: &str) -> anyhow::Result<Vec<SearchResult>> {
        /// Reads `key` from a result entry as an owned string, or `""` if absent or not a string.
        fn text_field(page: &serde_json::Value, key: &str) -> String {
            page.get(key)
                .and_then(|value| value.as_str())
                .unwrap_or("")
                .to_string()
        }

        let api_key = self
            .brave_api_key
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Brave API key required"))?;

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.timeout_secs))
            .build()?;

        let count = self.max_results.to_string();
        let response = client
            .get("https://api.search.brave.com/res/v1/web/search")
            .header("Accept", "application/json")
            .header("X-Subscription-Token", api_key)
            .query(&[("q", query), ("count", count.as_str())])
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty.
            let body = response.text().await.unwrap_or_default();
            anyhow::bail!("Brave search failed ({}): {}", status, body);
        }

        let data: serde_json::Value = response.json().await?;
        let results = data
            .get("web")
            .and_then(|web| web.get("results"))
            .and_then(|list| list.as_array())
            .map(|pages| {
                pages
                    .iter()
                    .take(self.max_results)
                    .map(|page| SearchResult {
                        title: text_field(page, "title"),
                        url: text_field(page, "url"),
                        snippet: text_field(page, "description"),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        Ok(results)
    }
}
/// One search hit, as produced by either search backend.
#[derive(Debug, Clone)]
struct SearchResult {
/// Result title (HTML entities are decoded for DuckDuckGo results).
title: String,
/// Target URL of the result.
url: String,
/// Short description of the result; always empty for DuckDuckGo results.
snippet: String,
}
#[async_trait]
impl Tool for WebSearchTool {
fn name(&self) -> &str {
"web_search"
}
fn description(&self) -> &str {
"Search the web using DuckDuckGo (free) or Brave (requires API key)"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results (1-10)",
"minimum": 1,
"maximum": 10,
"default": 5
}
},
"required": ["query"]
})
}
fn requires_network(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
let query = args
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.map(|n| n as usize)
.unwrap_or(DEFAULT_MAX_RESULTS)
.clamp(1, 10);
let results = match self.provider.as_str() {
"brave" => self.search_brave(query).await?,
"duckduckgo" => self.search_duckduckgo(query).await?,
_ => self.search_duckduckgo(query).await?,
};
let results_json: Vec<serde_json::Value> = results
.into_iter()
.take(max_results)
.map(|r| {
json!({
"title": r.title,
"url": r.url,
"snippet": r.snippet
})
})
.collect();
Ok(json!({
"success": true,
"query": query,
"provider": self.provider,
"results": results_json,
"count": results_json.len()
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Live smoke test against DuckDuckGo.
    ///
    /// A failed request is tolerated so the suite still passes without network access; when
    /// the request succeeds, the shape of the response is checked.
    #[tokio::test]
    async fn test_web_search_duckduckgo() {
        let tool = WebSearchTool::new("duckduckgo");
        let result = tool
            .execute(json!({
                "query": "Rust programming language",
                "max_results": 3
            }))
            .await;

        if let Ok(response) = result {
            assert_eq!(response["success"], true);

            let results = response
                .get("results")
                .and_then(|r| r.as_array())
                .expect("response must contain a 'results' array");
            let count = response["count"]
                .as_u64()
                .expect("'count' must be a non-negative integer");

            assert!(count <= 3, "no more than the requested 3 results");
            assert_eq!(results.len() as u64, count, "'count' must match 'results'");
        }
    }

    /// The schema must describe both supported parameters.
    #[test]
    fn test_web_search_schema() {
        let tool = WebSearchTool::new("duckduckgo");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["query"].is_object());
        assert!(schema["properties"]["max_results"].is_object());
    }

    /// Offline check of the DuckDuckGo parser: direct links are kept as-is, titles have their
    /// HTML entities decoded, snippets are empty, and the configured limit is respected.
    #[test]
    fn test_parse_duckduckgo_results_offline() {
        let html = concat!(
            r#"<a rel="nofollow" class="result__a" href="https://example.com/one">First &amp; best</a>"#,
            r#"<a rel="nofollow" class="result__a" href="https://example.com/two">Second</a>"#,
            r#"<a rel="nofollow" class="result__a" href="https://example.com/three">Third</a>"#,
        );

        let tool = WebSearchTool::new("duckduckgo").with_max_results(2);
        let results = tool
            .parse_duckduckgo_results(html)
            .expect("the result regex must compile");

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].title, "First & best");
        assert_eq!(results[0].url, "https://example.com/one");
        assert!(results[0].snippet.is_empty());
        assert_eq!(results[1].title, "Second");
        assert_eq!(results[1].url, "https://example.com/two");
    }
}