use anyhow::Result;
use async_trait::async_trait;
use reqwest::header::USER_AGENT;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::BTreeSet;
use super::{Tool, ToolOutput};
/// Arguments accepted by the `WebSearch` tool, deserialized from the tool-call JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WebSearchInput {
    /// Free-text search query sent to the search engine.
    pub query: String,
    /// When set, only results whose URL matches one of these domains are returned.
    pub allowed_domains: Option<Vec<String>>,
    /// When set, results whose URL matches one of these domains are dropped.
    pub blocked_domains: Option<Vec<String>>,
}
/// One search hit as reported back to the caller.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SearchResult {
    /// Result title with HTML markup removed.
    pub title: String,
    /// Destination URL of the result.
    pub url: String,
    /// Short text excerpt shown with the result; empty when none was found.
    pub snippet: String,
}
/// Payload returned by the `WebSearch` tool: the query that was run and its results.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WebSearchOutput {
    /// The query exactly as supplied in the input.
    pub query: String,
    /// Results that survived the domain filters, in page order.
    pub results: Vec<SearchResult>,
}
/// Tool that answers web-search requests by querying DuckDuckGo's HTML endpoint.
pub struct WebSearchTool {
    /// HTTP client reused for every search issued by this tool.
    client: reqwest::Client,
}
impl WebSearchTool {
    /// Creates a search tool backed by a newly constructed HTTP client.
    pub fn new() -> Self {
        let client = reqwest::Client::new();
        Self { client }
    }
}
impl Default for WebSearchTool {
fn default() -> Self {
Self::new()
}
}
/// Removes HTML tags from `html` and decodes character references in the remaining text.
///
/// Everything from a `<` up to the next `>` is dropped, and a stray `>` outside a tag is
/// dropped too. The remaining text is then entity-decoded in a single left-to-right pass
/// (see `decode_html_entities`), so an escaped reference such as `&amp;lt;` becomes the
/// literal text `&lt;` instead of being decoded twice into `<`.
fn strip_html_tags(html: &str) -> String {
    let mut text = String::with_capacity(html.len());
    let mut in_tag = false;
    for ch in html.chars() {
        match ch {
            '<' => in_tag = true,
            '>' => in_tag = false,
            _ if !in_tag => text.push(ch),
            _ => {}
        }
    }
    decode_html_entities(&text)
}

/// Decodes HTML character references (`&name;`, `&#NN;`, `&#xHH;`) in one pass.
///
/// Recognised names are `amp`, `lt`, `gt`, `quot`, `apos` and `nbsp`; `nbsp` is mapped to
/// a plain space. Unknown or malformed references, and a bare `&`, are copied verbatim.
fn decode_html_entities(text: &str) -> String {
    // Longest reference body we attempt to decode (e.g. `#x10FFFF` is 8 bytes). Bounding the
    // `;` search keeps text with many bare `&` characters linear.
    const MAX_ENTITY_LEN: usize = 10;

    let mut decoded = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        decoded.push_str(&rest[..amp]);
        let after = &rest[amp + 1..];
        // `;` is ASCII, so its byte position is always a valid slice boundary.
        let entity = after
            .bytes()
            .take(MAX_ENTITY_LEN + 1)
            .position(|b| b == b';')
            .and_then(|semi| decode_entity(&after[..semi]).map(|ch| (ch, semi)));
        match entity {
            Some((ch, semi)) => {
                decoded.push(ch);
                rest = &after[semi + 1..];
            }
            None => {
                decoded.push('&');
                rest = after;
            }
        }
    }
    decoded.push_str(rest);
    decoded
}

/// Maps the body of a character reference (text between `&` and `;`) to its character.
fn decode_entity(name: &str) -> Option<char> {
    match name {
        "amp" => Some('&'),
        "lt" => Some('<'),
        "gt" => Some('>'),
        "quot" => Some('"'),
        "apos" => Some('\''),
        "nbsp" => Some(' '),
        _ => decode_numeric_entity(name.strip_prefix('#')?),
    }
}

/// Decodes the digits of a numeric reference: decimal (`39`) or hex (`x27` / `X27`).
fn decode_numeric_entity(number: &str) -> Option<char> {
    let (digits, radix) = match number.strip_prefix('x').or_else(|| number.strip_prefix('X')) {
        Some(hex) => (hex, 16),
        None => (number, 10),
    };
    // Reject signs and other junk that `from_str_radix` would otherwise tolerate.
    if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }
    u32::from_str_radix(digits, radix).ok().and_then(char::from_u32)
}
#[async_trait]
impl Tool for WebSearchTool {
fn name(&self) -> &str {
"WebSearch"
}
fn description(&self) -> &str {
"Search the web for current information and return cited results."
}
fn parameters_schema(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string", "minLength": 2 },
"allowed_domains": {
"type": "array",
"items": { "type": "string" }
},
"blocked_domains": {
"type": "array",
"items": { "type": "string" }
}
},
"required": ["query"],
"additionalProperties": false
})
}
async fn execute(&self, input: Value) -> Result<ToolOutput> {
let search_input: WebSearchInput = serde_json::from_value(input)?;
let search_url = reqwest::Url::parse_with_params(
"https://html.duckduckgo.com/html/",
&[("q", &search_input.query)],
)?;
let res = self
.client
.get(search_url)
.header(
USER_AGENT,
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 \
(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
)
.send()
.await?;
let html_content = res.text().await?;
let mut results = Vec::new();
let allowed_set: Option<BTreeSet<String>> = search_input
.allowed_domains
.map(|list| list.into_iter().map(|d| d.to_lowercase()).collect());
let blocked_set: Option<BTreeSet<String>> = search_input
.blocked_domains
.map(|list| list.into_iter().map(|d| d.to_lowercase()).collect());
let mut cursor = 0;
while let Some(start_idx) = html_content[cursor..].find("<a class=\"result__a\"") {
let absolute_start = cursor + start_idx;
let current_slice = &html_content[absolute_start..];
let href_tag = "href=\"";
if let Some(href_start) = current_slice.find(href_tag) {
let url_start = href_start + href_tag.len();
if let Some(url_end) = current_slice[url_start..].find('"') {
let mut url = current_slice[url_start..(url_start + url_end)].to_string();
if url.contains("uddg=") {
if let Some(uddg_idx) = url.find("uddg=") {
let encoded_url = &url[uddg_idx + 5..];
let decoded_url = urlencoding::decode(encoded_url).unwrap_or_default().into_owned();
url = decoded_url.split('&').next().unwrap_or(&decoded_url).to_string();
}
}
let closing_bracket = ">";
if let Some(bracket_idx) = current_slice[url_start + url_end..].find(closing_bracket) {
let title_start = url_start + url_end + bracket_idx + closing_bracket.len();
if let Some(a_close_idx) = current_slice[title_start..].find("</a>") {
let raw_title = ¤t_slice[title_start..(title_start + a_close_idx)];
let title = strip_html_tags(raw_title).trim().to_string();
let mut snippet = String::new();
if let Some(snippet_container_idx) = current_slice.find("result__snippet") {
let snippet_slice = ¤t_slice[snippet_container_idx..];
if let Some(snippet_content_start) = snippet_slice.find('>') {
let snippet_start = snippet_container_idx + snippet_content_start + 1;
if let Some(snippet_content_end) = current_slice[snippet_start..].find('<') {
let raw_snippet = ¤t_slice[snippet_start..(snippet_start + snippet_content_end)];
snippet = strip_html_tags(raw_snippet).trim().to_string();
}
}
}
let matches_allowed = allowed_set.as_ref().map_or(true, |set| {
set.iter().any(|domain| url.to_lowercase().contains(domain))
});
let matches_blocked = blocked_set.as_ref().map_or(false, |set| {
set.iter().any(|domain| url.to_lowercase().contains(domain))
});
if matches_allowed && !matches_blocked && !url.is_empty() && !title.is_empty() {
results.push(SearchResult { title, url, snippet });
}
}
}
}
}
if let Some(closing_a) = current_slice.find("</a>") {
cursor = absolute_start + closing_a + 4;
} else {
cursor = absolute_start + 20;
}
if results.len() >= 6 {
break;
}
}
let output = WebSearchOutput {
query: search_input.query,
results,
};
let serialized = serde_json::to_string_pretty(&output)?;
Ok(ToolOutput::success(serialized))
}
}