use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A single search hit, as returned inside a [`SearchResponse`].
///
/// All fields are plain strings supplied by the provider; nothing here is
/// parsed or validated by this type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SearchResult {
/// Title of the result.
pub title: String,
/// URL of the result.
pub url: String,
/// Text excerpt for the result. The Brave provider fills this from the
/// result's `description` and uses an empty string when it is absent.
pub snippet: String,
/// Provider-supplied date/age string, if any. The Brave provider copies the
/// result's `age` field verbatim.
pub published_date: Option<String>,
}
/// The outcome of one [`SearchProvider::search`] call.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SearchResponse {
/// Query the results belong to. The Brave provider prefers the query echoed
/// back by the API and otherwise falls back to the caller's input.
pub query: String,
/// Hits, in the order the provider yielded them.
pub results: Vec<SearchResult>,
/// Total number of matches when a provider reports it. The Brave provider
/// always sets this to `None`.
pub total_results: Option<u64>,
}
/// Common interface for search backends.
///
/// Implementors must be `Send + Sync`; the `#[async_trait]` attribute is what
/// allows the `async fn` in this trait.
#[async_trait]
pub trait SearchProvider: Send + Sync {
/// Runs `query` against the backend and returns the hits.
///
/// `max_results` is the number of results the caller asks for; how strictly
/// it is honoured is up to the implementation (the Brave provider forwards
/// it as the API's `count` parameter).
///
/// # Errors
///
/// Returns an error when the implementation fails to obtain or decode
/// results.
async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse>;
/// Returns a static identifier for this backend (e.g. `"brave"`).
fn provider_name(&self) -> &'static str;
}
/// [`SearchProvider`] implementation that queries the Brave Search web search
/// endpoint (`https://api.search.brave.com/res/v1/web/search`).
#[derive(Clone)]
pub struct BraveSearchProvider {
/// HTTP client used to send the search requests.
client: reqwest::Client,
/// API key, sent with each request in the `X-Subscription-Token` header.
api_key: String,
}
impl BraveSearchProvider {
    /// Builds a provider that uses a freshly created `reqwest::Client`.
    ///
    /// `api_key` is stored as an owned `String` and later sent with every
    /// search request.
    #[must_use]
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_client(reqwest::Client::new(), api_key)
    }

    /// Builds a provider around a caller-supplied `reqwest::Client`.
    ///
    /// Use this when the client needs custom configuration; the client is
    /// stored as given and `api_key` is converted into an owned `String`.
    #[must_use]
    pub fn with_client(client: reqwest::Client, api_key: impl Into<String>) -> Self {
        let api_key = api_key.into();
        Self { client, api_key }
    }
}
/// Deserialization-only types for the JSON body of the Brave web search
/// endpoint. Only the fields this file reads are declared.
mod brave_api {
use serde::Deserialize;
/// Top-level response body. Both sections are optional; a missing `web`
/// section is treated by the caller as "no results".
#[derive(Debug, Deserialize)]
pub struct BraveSearchResponse {
/// Query information echoed back by the API, if present.
pub query: Option<BraveQuery>,
/// Web results section, if present.
pub web: Option<BraveWebResults>,
}
/// Query section of the response.
#[derive(Debug, Deserialize)]
pub struct BraveQuery {
/// Query string reported by the API under the `original` key; used as
/// `SearchResponse::query` when available.
pub original: String,
}
/// Web results section of the response.
#[derive(Debug, Deserialize)]
pub struct BraveWebResults {
/// Individual web hits. This field is required: deserialization fails if
/// `web` is present without `results`.
pub results: Vec<BraveWebResult>,
}
/// One web hit as returned by the API.
#[derive(Debug, Deserialize)]
pub struct BraveWebResult {
/// Result title.
pub title: String,
/// Result URL.
pub url: String,
/// Optional description; mapped to `SearchResult::snippet` (empty string
/// when absent).
pub description: Option<String>,
/// Optional age string; mapped to `SearchResult::published_date` unchanged.
pub age: Option<String>,
}
}
#[async_trait]
impl SearchProvider for BraveSearchProvider {
async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
let url = "https://api.search.brave.com/res/v1/web/search";
let response = self
.client
.get(url)
.header("X-Subscription-Token", &self.api_key)
.header("Accept", "application/json")
.query(&[
("q", query),
("count", &max_results.to_string()),
("text_decorations", "false"),
])
.send()
.await
.context("Failed to send request to Brave Search API")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Brave Search API error: {status} - {body}");
}
let brave_response: brave_api::BraveSearchResponse = response
.json()
.await
.context("Failed to parse Brave Search API response")?;
let results = brave_response
.web
.map(|web| {
web.results
.into_iter()
.map(|r| SearchResult {
title: r.title,
url: r.url,
snippet: r.description.unwrap_or_default(),
published_date: r.age,
})
.collect()
})
.unwrap_or_default();
let query_str = brave_response
.query
.map_or_else(|| query.to_string(), |q| q.original);
Ok(SearchResponse {
query: query_str,
results,
total_results: None,
})
}
fn provider_name(&self) -> &'static str {
"brave"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `SearchResult` serializes to JSON containing its title and URL.
    #[test]
    fn test_search_result_serialization() {
        let hit = SearchResult {
            title: String::from("Test Title"),
            url: String::from("https://example.com"),
            snippet: String::from("Test snippet"),
            published_date: Some(String::from("2024-01-01")),
        };

        let encoded = serde_json::to_string(&hit).expect("serialize");

        for needle in ["Test Title", "example.com"] {
            assert!(encoded.contains(needle));
        }
    }

    /// A `SearchResponse` serializes to JSON containing the query and the
    /// nested result's title.
    #[test]
    fn test_search_response_serialization() {
        let only_hit = SearchResult {
            title: String::from("Result 1"),
            url: String::from("https://example.com/1"),
            snippet: String::from("First result"),
            published_date: None,
        };
        let payload = SearchResponse {
            query: String::from("test query"),
            results: vec![only_hit],
            total_results: Some(100),
        };

        let encoded = serde_json::to_string(&payload).expect("serialize");

        for needle in ["test query", "Result 1"] {
            assert!(encoded.contains(needle));
        }
    }

    /// The default constructor yields a provider named "brave".
    #[test]
    fn test_brave_provider_creation() {
        let name = BraveSearchProvider::new("test-api-key").provider_name();
        assert_eq!(name, "brave");
    }

    /// A provider built around a custom client is still named "brave".
    #[test]
    fn test_brave_provider_with_custom_client() {
        let timeout = std::time::Duration::from_secs(30);
        let custom_client = reqwest::Client::builder()
            .timeout(timeout)
            .build()
            .expect("build client");

        let name = BraveSearchProvider::with_client(custom_client, "test-api-key").provider_name();
        assert_eq!(name, "brave");
    }
}