use super::{SearchError, SearchProvider, SearchResult, SearchResults};
use crate::config::SearchOptions;
use async_trait::async_trait;
/// Default Tavily search endpoint, used when no custom URL is configured.
const DEFAULT_API_URL: &str = "https://api.tavily.com/search";

/// Search provider backed by the Tavily search API.
///
/// Construct with [`TavilyProvider::new`] and optionally adjust it with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct TavilyProvider {
    /// API key sent with every request; redacted from `Debug` output.
    api_key: String,
    /// Optional endpoint override; `None` means the default Tavily endpoint.
    api_url: Option<String>,
    /// Value sent as the request's `search_depth` (`"basic"` or `"advanced"`).
    search_depth: String,
}

// A derived `Debug` would print the API key verbatim, leaking the credential into
// any log line or panic message that formats the provider.
impl std::fmt::Debug for TavilyProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("TavilyProvider")
            .field("api_key", &api_key)
            .field("api_url", &self.api_url)
            .field("search_depth", &self.search_depth)
            .finish()
    }
}

impl TavilyProvider {
    /// Creates a provider using `api_key`, the default endpoint and `"basic"` search depth.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            api_url: None,
            search_depth: "basic".to_string(),
        }
    }

    /// Overrides the endpoint that search requests are POSTed to.
    #[must_use]
    pub fn with_api_url(mut self, url: impl Into<String>) -> Self {
        self.api_url = Some(url.into());
        self
    }

    /// Switches the request's `search_depth` from `"basic"` to `"advanced"`.
    #[must_use]
    pub fn with_advanced_search(mut self) -> Self {
        self.search_depth = "advanced".to_string();
        self
    }

    /// Returns the endpoint URL: the custom URL if one was set, otherwise [`DEFAULT_API_URL`].
    fn endpoint(&self) -> &str {
        self.api_url.as_deref().unwrap_or(DEFAULT_API_URL)
    }
}
#[async_trait]
impl SearchProvider for TavilyProvider {
    /// Runs `query` against the configured Tavily endpoint and converts the
    /// response into [`SearchResults`].
    ///
    /// `options.limit` is capped at 10 and sent as `max_results`;
    /// `options.site_filter` / `options.exclude_domains` are sent as
    /// `include_domains` / `exclude_domains`. Response items without a
    /// non-empty `url` are skipped, and the remaining items are ranked
    /// contiguously from 1 in response order. The raw JSON response is stored
    /// in `results.metadata`.
    ///
    /// # Errors
    ///
    /// - [`SearchError::RequestFailed`] if the HTTP request cannot be sent.
    /// - [`SearchError::AuthenticationFailed`] on HTTP 401 or 403.
    /// - [`SearchError::RateLimited`] on HTTP 429.
    /// - [`SearchError::ProviderError`] on any other non-success status (the
    ///   message includes an excerpt of the response body when one is
    ///   available) or when the response body is not valid JSON.
    async fn search(
        &self,
        query: &str,
        options: &SearchOptions,
        client: &reqwest::Client,
    ) -> Result<SearchResults, SearchError> {
        // Upper bound on how much of an error response body is copied into the error message.
        const MAX_ERROR_DETAIL_CHARS: usize = 500;

        let mut body = serde_json::json!({
            "api_key": &self.api_key,
            "query": query,
            "search_depth": &self.search_depth
        });
        if let Some(limit) = options.limit {
            // Never ask for more than 10 results, whatever the caller requested.
            body["max_results"] = serde_json::json!(limit.min(10));
        }
        if let Some(ref sites) = options.site_filter {
            body["include_domains"] = serde_json::json!(sites);
        }
        if let Some(ref exclude) = options.exclude_domains {
            body["exclude_domains"] = serde_json::json!(exclude);
        }

        // `RequestBuilder::json` already sets `Content-Type: application/json`.
        let response = client
            .post(self.endpoint())
            .json(&body)
            .send()
            .await
            .map_err(|e| SearchError::RequestFailed(e.to_string()))?;

        let status = response.status();
        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            return Err(SearchError::AuthenticationFailed);
        }
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(SearchError::RateLimited);
        }
        if !status.is_success() {
            // Best effort: failing to read the body must not mask the status error itself.
            let detail = response.text().await.unwrap_or_default();
            let detail = detail.trim();
            let message = if detail.is_empty() {
                format!("HTTP {} from Tavily API", status)
            } else {
                let excerpt: String = detail.chars().take(MAX_ERROR_DETAIL_CHARS).collect();
                format!("HTTP {} from Tavily API: {}", status, excerpt)
            };
            return Err(SearchError::ProviderError(message));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| SearchError::ProviderError(format!("Failed to parse response: {}", e)))?;

        let mut results = SearchResults::new(query);

        // Drop url-less items *before* numbering, so ranks stay contiguous (1, 2, 3, ...)
        // instead of leaving gaps where an item was skipped.
        let items = json
            .get("results")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|item| {
                let url = item
                    .get("url")
                    .and_then(serde_json::Value::as_str)
                    .filter(|url| !url.is_empty())?;
                Some((item, url))
            });

        for (index, (item, url)) in items.enumerate() {
            let title = item
                .get("title")
                .and_then(serde_json::Value::as_str)
                .unwrap_or_default();
            let mut result = SearchResult::new(title, url, index + 1);
            if let Some(snippet) = item.get("content").and_then(serde_json::Value::as_str) {
                result = result.with_snippet(snippet);
            }
            if let Some(score) = item.get("score").and_then(serde_json::Value::as_f64) {
                result = result.with_score(score as f32);
            }
            if let Some(date) = item.get("published_date").and_then(serde_json::Value::as_str) {
                result = result.with_date(date);
            }
            results.push(result);
        }

        results.metadata = Some(json);
        Ok(results)
    }

    /// Identifier for this provider: `"tavily"`.
    fn provider_name(&self) -> &'static str {
        "tavily"
    }

    /// Reports the provider as usable when an API key is set, or when a custom
    /// endpoint URL is configured (even with an empty key).
    fn is_configured(&self) -> bool {
        // NOTE(review): the custom-URL case presumably targets an endpoint that does not
        // need the key in the request body (e.g. an auth-injecting proxy) — confirm.
        !self.api_key.is_empty() || self.api_url.is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tavily_provider_new() {
        let provider = TavilyProvider::new("test-key");
        assert_eq!(provider.endpoint(), DEFAULT_API_URL);
        assert_eq!(provider.search_depth, "basic");
        assert!(provider.is_configured());
    }

    #[test]
    fn test_tavily_provider_advanced() {
        let provider = TavilyProvider::new("test-key").with_advanced_search();
        assert_eq!(provider.search_depth, "advanced");
    }

    #[test]
    fn test_tavily_provider_custom_url() {
        let provider =
            TavilyProvider::new("test-key").with_api_url("https://custom.api.com/search");
        assert_eq!(provider.endpoint(), "https://custom.api.com/search");
    }

    #[test]
    fn test_tavily_provider_builders_compose() {
        // Builder calls are independent: applying both keeps each setting.
        let provider = TavilyProvider::new("test-key")
            .with_advanced_search()
            .with_api_url("https://custom.api.com/search");
        assert_eq!(provider.search_depth, "advanced");
        assert_eq!(provider.endpoint(), "https://custom.api.com/search");
    }

    #[test]
    fn test_tavily_provider_name() {
        assert_eq!(TavilyProvider::new("test-key").provider_name(), "tavily");
    }

    #[test]
    fn test_tavily_provider_unconfigured_without_key_or_url() {
        assert!(!TavilyProvider::new("").is_configured());
    }

    #[test]
    fn test_tavily_provider_configured_by_custom_url_alone() {
        // Pins current behavior: a custom endpoint counts as configured even with an empty key.
        let provider = TavilyProvider::new("").with_api_url("http://localhost:8080/search");
        assert!(provider.is_configured());
    }
}