use rig::completion::ToolDefinition;
use rig::tool::Tool;
use schemars::{JsonSchema, schema_for};
use serde::Deserialize;
use std::time::Duration;
use tavily2::{SearchResponse, Tavily, TavilyError};
/// A `rig` tool that performs web searches through the Tavily API.
///
/// Holds one or more Tavily API keys; the whole list is handed to the Tavily
/// client each time a search runs.
///
/// `Debug` is intentionally not derived so the API keys cannot end up in logs
/// by accident.
#[derive(Clone)]
pub struct TavilyTool {
    /// Tavily API keys. The constructors guarantee at least one entry; building
    /// the struct directly through this public field does not.
    pub api_keys: Vec<String>,
}

impl TavilyTool {
    /// Creates a tool that uses a single Tavily API key.
    pub fn new<S: Into<String>>(api_key: S) -> Self {
        let api_keys = vec![api_key.into()];
        Self { api_keys }
    }

    /// Creates a tool that uses several Tavily API keys, preserving their order.
    ///
    /// # Panics
    ///
    /// Panics if `api_key` contains no keys.
    pub fn new_with_keys<S: Into<String>>(api_key: Vec<S>) -> Self {
        // An empty key list could never authenticate, so reject it up front.
        assert!(!api_key.is_empty(), "Api key should be greater than 0");
        let api_keys: Vec<String> = api_key.into_iter().map(Into::into).collect();
        Self { api_keys }
    }
}
// NOTE: `///` doc comments on this type and its fields are emitted by
// `schemars` as `description` entries in the tool's JSON schema, i.e. they are
// what the model reads. Keep them model-facing; use `//` for developer notes.
/// Arguments for a Tavily web search.
#[derive(Deserialize, JsonSchema, Debug, Clone)]
pub struct TavilyArgs {
    /// The search query to send to Tavily, written as a concise natural-language search request.
    pub query: String,
}
impl Tool for TavilyTool {
    // Exposed to the model as the function name. It must be unique within the
    // toolset and use only `[a-zA-Z0-9_-]` (OpenAI function-name rule), so it
    // cannot contain spaces and must not reuse another tool's name.
    const NAME: &'static str = "tavily_search";
    type Error = TavilyError;
    type Args = TavilyArgs;
    type Output = SearchResponse;

    /// Describes the tool to the model.
    ///
    /// The description is a Chinese runtime string meaning "use Tavily to
    /// search the web". `parameters` is the JSON schema derived from
    /// `TavilyArgs`.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        let schema = schema_for!(Self::Args);
        ToolDefinition {
            name: Self::NAME.to_string(),
            description: "使用Tavily进行联网搜索".to_string(),
            parameters: serde_json::to_value(schema)
                .expect("a schemars-generated schema always serializes to JSON"),
        }
    }

    /// Runs a Tavily search for `args.query`.
    ///
    /// A client is built from all configured keys on every call. It has a 60 s
    /// timeout, `max_retries(5)` and `multi_retry(true)`. The exact retry and
    /// key-rotation semantics are defined by `tavily2`; confirm them there.
    ///
    /// # Errors
    ///
    /// Returns a `TavilyError` if the client cannot be built or the search fails.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

        let tavily = Tavily::builder_with_keys(self.api_keys.clone())
            .timeout(REQUEST_TIMEOUT)
            .max_retries(5)
            .multi_retry(true)
            .build()?;
        let result = tavily.search(args.query).await?;
        Ok(result)
    }
}