use rucora_core::{
error::ToolError,
tool::{Tool, ToolCategory},
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
/// Web-search tool that queries the Tavily search API (`https://api.tavily.com/search`).
///
/// Holds one or more API keys; every `call` picks one of them for its request.
pub struct TavilyTool {
    /// Tavily API keys. The constructors in this file never produce an empty list:
    /// `new` stores exactly one key, `with_keys` panics on an empty vector and
    /// `from_env` returns an error when no key is found.
    api_keys: Vec<String>,
}
impl TavilyTool {
    /// Creates a tool that always uses the single given Tavily API key.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_keys: vec![api_key.into()],
        }
    }

    /// Creates a tool that rotates across several Tavily API keys.
    ///
    /// Keys are stored exactly as given (no trimming or filtering).
    ///
    /// # Panics
    ///
    /// Panics if `api_keys` is empty.
    pub fn with_keys(api_keys: Vec<String>) -> Self {
        assert!(!api_keys.is_empty(), "API Keys 不能为空");
        Self { api_keys }
    }

    /// Builds the tool from environment variables.
    ///
    /// `TAVILY_API_KEYS` is preferred and `TAVILY_API_KEY` is the fallback. Either variable may
    /// hold a comma-separated list. Entries are trimmed and empty entries are dropped. A
    /// `TAVILY_API_KEYS` that yields no usable key (for example, exported but left blank) does
    /// not shadow a usable `TAVILY_API_KEY`.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::Message`] in two cases:
    /// - neither variable is set (or neither holds valid Unicode);
    /// - no non-empty key remains after parsing.
    pub fn from_env() -> Result<Self, ToolError> {
        /// Splits a comma-separated key list, trimming entries and dropping empty ones.
        fn parse_keys(raw: &str) -> Vec<String> {
            raw.split(',')
                .map(str::trim)
                .filter(|key| !key.is_empty())
                .map(str::to_owned)
                .collect()
        }

        let keys_var = std::env::var("TAVILY_API_KEYS").ok();
        let key_var = std::env::var("TAVILY_API_KEY").ok();
        if keys_var.is_none() && key_var.is_none() {
            return Err(ToolError::Message(
                "缺少环境变量 TAVILY_API_KEYS 或 TAVILY_API_KEY".to_string(),
            ));
        }

        // Fall back to TAVILY_API_KEY whenever TAVILY_API_KEYS is unset OR yields no key.
        let api_keys = keys_var
            .as_deref()
            .map(parse_keys)
            .filter(|keys| !keys.is_empty())
            .or_else(|| key_var.as_deref().map(parse_keys))
            .unwrap_or_default();

        if api_keys.is_empty() {
            return Err(ToolError::Message("API Keys 不能为空".to_string()));
        }
        Ok(Self { api_keys })
    }
}
/// Arguments accepted by the `tavily_search` tool, deserialized from the tool-call JSON input.
///
/// Only `query` is required. Every other field has a serde default supplied by
/// `default_max_results`, `default_true`, `default_search_depth`, or `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TavilyArgs {
    /// Search query text.
    pub query: String,
    /// Requested number of results (default 5). `call` caps the value sent to the API at 15.
    #[serde(default = "default_max_results")]
    pub max_results: usize,
    /// Whether to request an AI-generated answer (default `true`).
    #[serde(default = "default_true")]
    pub include_answer: bool,
    /// Whether to request raw page content (default `false`).
    #[serde(default)]
    pub include_raw_content: bool,
    /// Search depth, forwarded verbatim to the API (default `"basic"`). The input schema
    /// advertises `"basic"` and `"advanced"`, but the value is not validated here.
    #[serde(default = "default_search_depth")]
    pub search_depth: String,
}
/// Serde default for `TavilyArgs::max_results`: five results per search.
fn default_max_results() -> usize {
    const FIVE_RESULTS: usize = 5;
    FIVE_RESULTS
}
/// Serde default for flags that are enabled unless the caller opts out
/// (used by `TavilyArgs::include_answer`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde default for `TavilyArgs::search_depth`: the `"basic"` depth.
fn default_search_depth() -> String {
    String::from("basic")
}
/// Typed view of a Tavily search response.
///
/// NOTE(review): `call` does not construct this type; it returns a hand-built
/// `serde_json::Value` instead. The fields presumably mirror the API response, so verify them
/// against the Tavily API reference before relying on this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TavilySearchResult {
    /// AI-generated answer, if one was returned. Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub answer: Option<String>,
    /// Individual search hits.
    pub results: Vec<TavilyResult>,
    /// The query the results belong to.
    pub query: String,
}
/// A single search hit inside [`TavilySearchResult::results`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TavilyResult {
    /// Page title.
    pub title: String,
    /// Page URL.
    pub url: String,
    /// Content snippet for the page.
    pub content: String,
    /// Relevance score as reported by Tavily (presumably higher means more relevant; confirm
    /// against the API docs).
    pub score: f32,
}
#[async_trait]
impl Tool for TavilyTool {
fn name(&self) -> &str {
"tavily_search"
}
fn description(&self) -> Option<&str> {
Some("使用 Tavily AI 进行智能搜索,需要设置 TAVILY_API_KEY 环境变量")
}
fn categories(&self) -> &'static [ToolCategory] {
&[ToolCategory::Network]
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索关键词"
},
"max_results": {
"type": "integer",
"description": "搜索结果数量(默认 5,最大 15)"
},
"include_answer": {
"type": "boolean",
"description": "是否包含 AI 生成的答案(默认 true)"
},
"search_depth": {
"type": "string",
"description": "搜索深度:basic(基础), advanced(高级)",
"enum": ["basic", "advanced"]
}
},
"required": ["query"]
})
}
async fn call(&self, input: Value) -> Result<Value, ToolError> {
let args: TavilyArgs = serde_json::from_value(input)
.map_err(|e| ToolError::Message(format!("解析参数失败:{e}")))?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default();
let idx = (now.subsec_nanos() as usize) % self.api_keys.len();
let api_key = self.api_keys[idx].clone();
let request_body = json!({
"query": args.query,
"api_key": api_key,
"max_results": args.max_results.min(15),
"include_answer": args.include_answer,
"include_raw_content": args.include_raw_content,
"search_depth": args.search_depth,
});
let client = reqwest::Client::new();
let response = client
.post("https://api.tavily.com/search")
.json(&request_body)
.send()
.await
.map_err(|e| ToolError::Message(format!("请求失败:{e}")))?;
if !response.status().is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "未知错误".to_string());
return Err(ToolError::Message(format!(
"Tavily API 错误:{error_text}"
)));
}
let search_result: Value = response
.json()
.await
.map_err(|e| ToolError::Message(format!("解析 JSON 失败:{e}")))?;
Ok(json!({
"success": true,
"answer": search_result.get("answer").cloned(),
"results": search_result.get("results").cloned().unwrap_or(json!([])),
"query": search_result.get("query").cloned().unwrap_or(json!(args.query))
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction from the environment.
    ///
    /// This is the only test in this module that touches the process environment. Keep it
    /// that way: `set_var`/`remove_var` race with concurrent readers on other test threads.
    #[test]
    fn test_tavily_from_env() {
        // SAFETY: assumes no other test in this crate reads or writes these variables
        // concurrently; within this module only this test touches the environment.
        unsafe {
            // Clear the preferred variable so an externally exported TAVILY_API_KEYS cannot
            // shadow the value under test.
            std::env::remove_var("TAVILY_API_KEYS");
            std::env::set_var("TAVILY_API_KEY", " test_key_1 , ,test_key_2 ");
        }
        let result = TavilyTool::from_env();

        // SAFETY: same assumption as above.
        unsafe {
            std::env::remove_var("TAVILY_API_KEY");
        }
        let missing = TavilyTool::from_env();

        let tool = result.expect("from_env should succeed when TAVILY_API_KEY is set");
        assert_eq!(
            tool.api_keys,
            vec!["test_key_1".to_string(), "test_key_2".to_string()]
        );
        assert!(missing.is_err(), "from_env must fail when no key variable is set");
    }

    #[test]
    fn test_tavily_new_stores_single_key() {
        let tool = TavilyTool::new("only_key");
        assert_eq!(tool.api_keys, vec!["only_key".to_string()]);
    }

    #[test]
    #[should_panic(expected = "API Keys 不能为空")]
    fn test_tavily_with_keys_rejects_empty() {
        let _ = TavilyTool::with_keys(Vec::new());
    }

    /// Exercises the serde defaults themselves, not just the helper functions.
    #[test]
    fn test_tavily_args_default() {
        let args: TavilyArgs = serde_json::from_value(serde_json::json!({ "query": "test" }))
            .expect("a bare query is valid input");
        assert_eq!(args.query, "test");
        assert_eq!(args.max_results, 5);
        assert!(args.include_answer);
        assert!(!args.include_raw_content);
        assert_eq!(args.search_depth, "basic");
    }

    #[test]
    fn test_tavily_args_requires_query() {
        assert!(serde_json::from_value::<TavilyArgs>(serde_json::json!({})).is_err());
    }

    #[test]
    fn test_search_result_omits_missing_answer() {
        let result = TavilySearchResult {
            answer: None,
            results: Vec::new(),
            query: "q".to_string(),
        };
        let value = serde_json::to_value(&result).expect("result is serializable");
        assert!(value.get("answer").is_none());
        assert_eq!(value["query"], "q");
    }
}