use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::core::tools::{BaseTool, Tool, ToolError};
// NOTE: `schemars` copies the `///` doc comments below into the `description`
// fields of the JSON schema returned by `args_schema`, so keep them
// model-facing (they are what the calling agent reads).
/// Arguments for the web search tool.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct SearchInput {
    /// Search keywords; must not be empty or whitespace-only.
    pub query: String,
    /// Maximum number of results to return (defaults to 5 when omitted).
    pub top_k: Option<usize>,
}
/// Structured result of a web search, as produced by `invoke`.
#[derive(Debug, Serialize)]
pub struct SearchOutput {
    /// The query string exactly as supplied by the caller.
    pub query: String,
    /// Search results in the order they were collected from the API response.
    pub results: Vec<SearchResult>,
    /// Number of entries in `results`.
    pub total: usize,
    /// The response's `AbstractText`, if present and non-empty.
    pub abstract_text: Option<String>,
}
/// A single search hit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct SearchResult {
    /// Display title: the last path segment of the result URL with `_` turned
    /// into spaces, or the query itself for an abstract-only answer.
    pub title: String,
    /// Descriptive text of the hit.
    pub snippet: String,
    /// Source URL; empty when the API response did not supply one.
    pub url: String,
}
/// Web search tool backed by the DuckDuckGo Instant Answer API
/// (`https://api.duckduckgo.com`); no API key is required.
///
/// `Debug` and `Clone` are derived so the tool can be logged and shared.
/// Cloning is cheap because `reqwest::Client` shares its connection pool
/// internally.
#[derive(Debug, Clone)]
pub struct DuckDuckGoSearchTool {
    /// HTTP client reused across all searches.
    client: reqwest::Client,
}
impl DuckDuckGoSearchTool {
    /// Creates a search tool whose HTTP client has a 15-second timeout and a
    /// fixed `User-Agent`.
    ///
    /// If that client cannot be built, this falls back to
    /// `reqwest::Client::new()`, which has default settings: no custom timeout
    /// and no custom user agent. Construction therefore never fails.
    pub fn new() -> Self {
        let configured = reqwest::Client::builder()
            .user_agent("LangChainRust/0.1 (Search Tool)")
            .timeout(std::time::Duration::from_secs(15))
            .build();
        let client = match configured {
            Ok(built) => built,
            Err(_) => reqwest::Client::new(),
        };
        Self { client }
    }
}
impl Default for DuckDuckGoSearchTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for DuckDuckGoSearchTool {
type Input = SearchInput;
type Output = SearchOutput;
async fn invoke(&self, input: Self::Input) -> Result<Self::Output, ToolError> {
if input.query.trim().is_empty() {
return Err(ToolError::InvalidInput("搜索查询不能为空".to_string()));
}
let top_k = input.top_k.unwrap_or(5);
let url = format!(
"https://api.duckduckgo.com/?q={}&format=json&no_html=1",
urlencoding(&input.query)
);
let response = self.client.get(&url)
.send()
.await
.map_err(|e| ToolError::ExecutionFailed(format!("搜索请求失败: {}", e)))?;
let body: serde_json::Value = response.json().await
.map_err(|e| ToolError::ExecutionFailed(format!("解析搜索结果失败: {}", e)))?;
let abstract_text = body["AbstractText"].as_str()
.filter(|s| !s.is_empty())
.map(|s| s.to_string());
let mut results = Vec::new();
if let Some(topics) = body["RelatedTopics"].as_array() {
for topic in topics.iter() {
if results.len() >= top_k {
break;
}
if let Some(text) = topic["Text"].as_str() {
let title = topic["FirstURL"].as_str()
.map(|u| u.rsplit('/').next().unwrap_or(u).replace('_', " "))
.unwrap_or_default();
let url = topic["FirstURL"].as_str().unwrap_or("");
results.push(SearchResult {
title,
snippet: text.to_string(),
url: url.to_string(),
});
}
if let Some(nested) = topic["Topics"].as_array() {
for nt in nested.iter() {
if results.len() >= top_k {
break;
}
if let Some(text) = nt["Text"].as_str() {
let title = nt["FirstURL"].as_str()
.map(|u| u.rsplit('/').next().unwrap_or(u).replace('_', " "))
.unwrap_or_default();
let url = nt["FirstURL"].as_str().unwrap_or("");
results.push(SearchResult {
title,
snippet: text.to_string(),
url: url.to_string(),
});
}
}
}
}
}
if results.is_empty() {
if let Some(ref abstract_text) = abstract_text {
let source_url = body["AbstractURL"].as_str().unwrap_or("");
results.push(SearchResult {
title: input.query.clone(),
snippet: abstract_text.clone(),
url: source_url.to_string(),
});
}
}
Ok(SearchOutput {
query: input.query,
total: results.len(),
results,
abstract_text,
})
}
}
/// Percent-encodes `s` for use as the value of a URL query parameter.
///
/// Kept as a thin alias for existing call sites; see [`url_encode`].
fn urlencoding(s: &str) -> String {
    url_encode(s)
}

/// Percent-encodes `s` per RFC 3986 for use inside a URL query component.
///
/// Every byte of the UTF-8 encoding is written as `%XX` (uppercase hex), except
/// the unreserved characters `A-Z a-z 0-9 - _ . ~`, which pass through
/// unchanged. This escapes, among others:
/// * `%`, which would otherwise start a bogus escape sequence;
/// * `+`, which form-style parsers decode as a space;
/// * `/`, `;` and quotes;
/// * every non-ASCII byte.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                encoded.push(char::from(byte));
            }
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}
#[async_trait]
impl BaseTool for DuckDuckGoSearchTool {
fn name(&self) -> &str {
"web_search"
}
fn description(&self) -> &str {
"网页搜索工具。使用 DuckDuckGo 搜索引擎搜索网络信息,无需 API Key。
参数:
- query: 搜索关键词
- top_k: 返回结果数量(默认 5)
示例:
- {\"query\": \"Rust programming language\", \"top_k\": 3}"
}
async fn run(&self, input: String) -> Result<String, ToolError> {
let parsed: SearchInput = serde_json::from_str(&input)
.map_err(|e| ToolError::InvalidInput(format!("JSON 解析失败: {}", e)))?;
let output = self.invoke(parsed).await?;
let mut text = format!("搜索结果 (查询: {})\n\n", output.query);
if let Some(ref abstract_text) = output.abstract_text {
text.push_str(&format!("摘要: {}\n\n", abstract_text));
}
for (i, result) in output.results.iter().enumerate() {
text.push_str(&format!("{}. {}\n", i + 1, result.title));
text.push_str(&format!(" {}\n", result.snippet));
text.push_str(&format!(" URL: {}\n\n", result.url));
}
if output.results.is_empty() {
text.push_str("未找到相关结果");
} else {
text.push_str(&format!("共 {} 条结果", output.total));
}
Ok(text)
}
fn args_schema(&self) -> Option<serde_json::Value> {
use schemars::schema_for;
serde_json::to_value(schema_for!(SearchInput)).ok()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Name, description, and argument schema are all available.
    #[test]
    fn test_search_tool_properties() {
        let search = DuckDuckGoSearchTool::new();
        let schema = BaseTool::args_schema(&search);
        assert_eq!(search.name(), "web_search");
        assert!(search.description().contains("DuckDuckGo"));
        assert!(schema.is_some());
    }

    /// An empty query is rejected; `invoke` checks it before issuing a request.
    #[tokio::test]
    async fn test_search_empty_query() {
        let search = DuckDuckGoSearchTool::new();
        let payload = String::from(r#"{"query": ""}"#);
        let outcome = search.run(payload).await;
        assert!(outcome.is_err());
    }

    /// Live smoke test. A request error is only logged, not treated as a test
    /// failure, since the remote service may throttle or be unreachable.
    #[tokio::test]
    #[ignore = "需要网络连接"]
    async fn test_search_real() {
        let search = DuckDuckGoSearchTool::new();
        let request = SearchInput {
            query: "Rust programming".to_string(),
            top_k: Some(3),
        };
        match search.invoke(request).await {
            Ok(output) => {
                println!("搜索到了 {} 条结果", output.total);
                if let Some(first) = output.results.first() {
                    assert!(!first.title.is_empty());
                }
            }
            Err(e) => eprintln!("搜索可能被限制: {}", e),
        }
    }
}