use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, error, info, warn};
use super::{SearchOptions, SearchProvider, SearchResult};
use crate::utils::cache::CacheBackend;
use crate::utils::error::OpenCratesError;
/// Tunable settings for [`WebSearchProvider`]: HTTP behaviour, rate
/// limiting, caching, and the list of search engines to query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSearchConfig {
// Value of the HTTP `User-Agent` header sent with every request.
pub user_agent: String,
// Per-request timeout applied to the underlying HTTP client.
pub timeout_seconds: u64,
// Maximum retry attempts per request. NOTE(review): not consulted by any
// request path visible in this file — confirm it is used elsewhere.
pub max_retries: u32,
// Upper bound on outgoing requests per second when rate limiting is on.
pub rate_limit_per_second: u32,
// Time-to-live for cached search results, in seconds.
pub cache_ttl_seconds: u64,
// Search backends to query, in order.
pub engines: Vec<SearchEngine>,
// Enables result caching (see `WebSearchProvider::with_cache`).
pub enable_caching: bool,
// Enables the request-spacing logic in `enforce_rate_limit`.
pub enable_rate_limiting: bool,
}
/// One search backend plus the credentials/overrides it needs.
/// `base_url`/`api_url` overrides fall back to the public endpoints when
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchEngine {
// DuckDuckGo Instant Answer API (no credentials required).
DuckDuckGo {
base_url: Option<String>,
},
// Bing Web Search. NOTE(review): not handled by
// `aggregate_search_results` in this file — it is skipped with a warning.
Bing {
api_key: String,
subscription_key: String,
},
// Google Custom Search. NOTE(review): also unhandled by
// `aggregate_search_results` in this file.
Google {
api_key: String,
search_engine_id: String,
custom_search_url: Option<String>,
},
// crates.io registry search API.
CratesIo {
base_url: Option<String>,
},
// docs.rs documentation host.
DocsRs {
base_url: Option<String>,
},
// GitHub repository search; `token` enables authenticated rate limits.
GitHub {
token: Option<String>,
api_url: Option<String>,
},
}
impl Default for WebSearchConfig {
    /// Sensible defaults: 30 s timeout, 10 req/s, 1 h cache TTL, and the
    /// three credential-free engines (DuckDuckGo, crates.io, docs.rs).
    fn default() -> Self {
        let engines = vec![
            SearchEngine::DuckDuckGo { base_url: None },
            SearchEngine::CratesIo { base_url: None },
            SearchEngine::DocsRs { base_url: None },
        ];
        Self {
            user_agent: "OpenCrates/1.0 (Rust Crate Generator)".to_string(),
            timeout_seconds: 30,
            max_retries: 3,
            rate_limit_per_second: 10,
            cache_ttl_seconds: 3600,
            engines,
            enable_caching: true,
            enable_rate_limiting: true,
        }
    }
}
/// Multi-engine web search provider with per-engine statistics, optional
/// caching, and request rate limiting.
#[derive(Debug)]
pub struct WebSearchProvider {
// Shared HTTP client (carries the configured User-Agent and timeout).
client: reqwest::Client,
config: WebSearchConfig,
// Optional cache backend. NOTE(review): the cache helper methods below
// are stubs, so this field is currently never read — confirm wiring.
cache: Option<Arc<dyn CacheBackend<String, String>>>,
// Bounds how many callers may be inside the rate-limit gate at once.
rate_limiter: Arc<Semaphore>,
// Rolling request/latency statistics keyed by engine name.
engine_stats: Arc<RwLock<HashMap<String, EngineStats>>>,
// Timestamp of the most recent outgoing request (for request spacing).
last_request_time: Arc<RwLock<Instant>>,
}
/// Aggregated request statistics for one search engine.
#[derive(Debug, Clone, Default)]
pub struct EngineStats {
// Total requests attempted (successes + failures).
pub requests: u64,
pub successful_requests: u64,
pub failed_requests: u64,
// Running mean latency over all requests, success and failure alike.
pub avg_response_time: Duration,
// Message from the most recent failed request, if any.
pub last_error: Option<String>,
}
impl Default for WebSearchProvider {
/// Equivalent to [`WebSearchProvider::new`].
fn default() -> Self {
Self::new()
}
}
impl WebSearchProvider {
/// Creates a provider with [`WebSearchConfig::default`] settings.
#[must_use]
pub fn new() -> Self {
Self::with_config(WebSearchConfig::default())
}
/// Builds a provider from an explicit configuration.
///
/// An invalid `user_agent` falls back to a static default instead of
/// panicking (the previous `unwrap()` aborted on any value that is not a
/// legal HTTP header).
///
/// # Panics
/// Panics only if the underlying `reqwest` client itself cannot be built.
pub fn with_config(config: WebSearchConfig) -> Self {
    let mut headers = HeaderMap::new();
    // A malformed user agent must not bring the whole provider down; use
    // a known-good static fallback.
    let user_agent = HeaderValue::from_str(&config.user_agent)
        .unwrap_or_else(|_| HeaderValue::from_static("OpenCrates/1.0"));
    headers.insert(USER_AGENT, user_agent);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(config.timeout_seconds))
        .default_headers(headers)
        .build()
        .expect("Failed to create HTTP client");
    // `Semaphore::new(0)` would never grant a permit and would deadlock
    // `enforce_rate_limit`; treat a zero rate limit as a single permit.
    let rate_limiter = Arc::new(Semaphore::new(config.rate_limit_per_second.max(1) as usize));
    Self {
        client,
        config,
        cache: None,
        rate_limiter,
        engine_stats: Arc::new(RwLock::new(HashMap::new())),
        last_request_time: Arc::new(RwLock::new(Instant::now())),
    }
}
/// Attaches a cache backend (builder-style).
///
/// NOTE(review): `get_cached_result`/`cache_result` below are stubs, so
/// the backend is stored but never consulted — confirm intended wiring.
pub fn with_cache(mut self, cache: Arc<dyn CacheBackend<String, String>>) -> Self {
self.cache = Some(cache);
self
}
/// Returns a snapshot (clone) of the per-engine request statistics.
pub async fn get_engine_stats(&self) -> HashMap<String, EngineStats> {
self.engine_stats.read().await.clone()
}
/// Records the outcome of one request against `engine`: bumps the
/// success/failure counters, remembers the latest error, and folds the
/// response time into a running mean.
async fn update_engine_stats(
    &self,
    engine: &str,
    success: bool,
    response_time: Duration,
    error: Option<String>,
) {
    let mut stats = self.engine_stats.write().await;
    let entry = stats.entry(engine.to_string()).or_default();
    entry.requests += 1;
    if success {
        entry.successful_requests += 1;
    } else {
        entry.failed_requests += 1;
        entry.last_error = error;
    }
    // Incremental mean: avg_n = (avg_{n-1} * (n-1) + x_n) / n.
    // `requests` was just incremented, so it is always >= 1 here.
    let n = entry.requests as f64;
    let previous_avg = entry.avg_response_time.as_secs_f64();
    let updated_avg = (previous_avg * (n - 1.0) + response_time.as_secs_f64()) / n;
    entry.avg_response_time = Duration::from_secs_f64(updated_avg);
}
/// Spaces outgoing requests so at most `rate_limit_per_second` are issued
/// per second. No-op when rate limiting is disabled, or when the
/// configured limit is zero (which would otherwise divide by zero below).
///
/// # Errors
/// Returns an internal error if the semaphore has been closed.
async fn enforce_rate_limit(&self) -> Result<(), OpenCratesError> {
    if !self.config.enable_rate_limiting {
        return Ok(());
    }
    let per_second = self.config.rate_limit_per_second;
    // Guard: `1000 / 0` would panic, and a zero-permit semaphore could
    // never be acquired. Treat a zero limit as "no limiting".
    if per_second == 0 {
        return Ok(());
    }
    let _permit = self.rate_limiter.acquire().await.map_err(|e| {
        OpenCratesError::internal(format!("Failed to acquire rate limit permit: {e}"))
    })?;
    // Holding the (tokio) write lock across the sleep intentionally
    // serializes concurrent callers so the interval applies globally.
    let mut last_time = self.last_request_time.write().await;
    let now = Instant::now();
    let min_interval = Duration::from_millis(1000 / u64::from(per_second));
    if let Some(sleep_duration) = min_interval.checked_sub(now.duration_since(*last_time)) {
        tokio::time::sleep(sleep_duration).await;
    }
    *last_time = Instant::now();
    Ok(())
}
/// Looks up a previously cached result for `_cache_key`.
///
/// Stub: always returns `None`. NOTE(review): neither `self.cache` nor
/// `config.enable_caching` is consulted — caching is effectively disabled
/// until this is implemented against the `CacheBackend` trait.
async fn get_cached_result(&self, _cache_key: &str) -> Option<String> {
None
}
/// Stores `_result` under `_cache_key`.
///
/// Stub: does nothing. NOTE(review): see `get_cached_result` — the
/// configured cache backend is never written to.
async fn cache_result(&self, _cache_key: &str, _result: &str) {
}
/// Searches the crates.io registry API for crates matching `query` and
/// returns at most `limit` results.
///
/// Per-engine statistics are updated on both success and failure; the
/// cache fast-path is currently a no-op (see `get_cached_result`).
///
/// # Errors
/// Returns `OpenCratesError::Search` if the HTTP request fails or the
/// response body cannot be parsed as a crates.io search response.
async fn search_crates_io(
&self,
query: &str,
limit: usize,
) -> Result<Vec<SearchResult>, OpenCratesError> {
let start_time = Instant::now();
let engine_name = "crates.io";
self.enforce_rate_limit().await?;
// Cache fast-path (no-op while the cache helpers are stubs).
let cache_key = format!("crates_io:{query}:{limit}");
if let Some(cached) = self.get_cached_result(&cache_key).await {
if let Ok(results) = serde_json::from_str::<Vec<SearchResult>>(&cached) {
self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
.await;
return Ok(results);
}
}
// Configured override if present, otherwise the public registry.
let base_url = self
.config
.engines
.iter()
.find_map(|e| match e {
SearchEngine::CratesIo { base_url } => base_url.as_ref(),
_ => None,
})
.map_or("https://crates.io", std::string::String::as_str);
let url = format!(
"{}/api/v1/crates?q={}&per_page={}",
base_url,
urlencoding::encode(query),
limit
);
let response = match self.client.get(&url).send().await {
Ok(resp) => resp,
Err(e) => {
let error_msg = format!("Failed to search crates.io: {e}");
self.update_engine_stats(
engine_name,
false,
start_time.elapsed(),
Some(error_msg.clone()),
)
.await;
return Err(OpenCratesError::Search(error_msg));
}
};
let crates_response: CratesIoResponse = match response.json().await {
Ok(data) => data,
Err(e) => {
let error_msg = format!("Failed to parse crates.io response: {e}");
self.update_engine_stats(
engine_name,
false,
start_time.elapsed(),
Some(error_msg.clone()),
)
.await;
return Err(OpenCratesError::Search(error_msg));
}
};
// Map registry entries to generic results. Relevance is a crude
// downloads/1000 score — unbounded, so very popular crates dominate.
let results: Vec<SearchResult> = crates_response
.crates
.into_iter()
.take(limit)
.map(|c| SearchResult {
title: c.name.clone(),
description: c
.description
.unwrap_or_else(|| format!("Rust crate: {}", c.name)),
// Strip a trailing "/api/v1" in case the override points at the
// API root rather than the site root.
url: format!("{}/crates/{}", base_url.trim_end_matches("/api/v1"), c.id),
relevance_score: c.downloads.unwrap_or(0) as f32 / 1000.0, })
.collect();
if let Ok(serialized) = serde_json::to_string(&results) {
self.cache_result(&cache_key, &serialized).await;
}
self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
.await;
Ok(results)
}
/// Produces docs.rs documentation links for `query`.
///
/// No docs.rs search API is called here: the results are synthesized
/// links of the form `{base_url}/{query}/latest/{query}/` with strictly
/// decreasing relevance, capped at five entries.
///
/// # Errors
/// Returns an error only if rate limiting fails.
async fn search_docs_rs(
    &self,
    query: &str,
    limit: usize,
) -> Result<Vec<SearchResult>, OpenCratesError> {
    let start_time = Instant::now();
    let engine_name = "docs.rs";
    self.enforce_rate_limit().await?;
    // Cache fast-path (no-op while the cache helpers are stubs).
    let cache_key = format!("docs_rs:{query}:{limit}");
    if let Some(cached) = self.get_cached_result(&cache_key).await {
        if let Ok(results) = serde_json::from_str::<Vec<SearchResult>>(&cached) {
            self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
                .await;
            return Ok(results);
        }
    }
    // Configured override if present, otherwise the public docs.rs host.
    let base_url = self
        .config
        .engines
        .iter()
        .find_map(|e| match e {
            SearchEngine::DocsRs { base_url } => base_url.as_ref(),
            _ => None,
        })
        .map_or("https://docs.rs", std::string::String::as_str);
    debug!("Searching docs.rs for: {}", query);
    // Synthesize up to five candidate documentation links, 0.9, 0.8, ...
    let results: Vec<SearchResult> = (0..limit.min(5))
        .map(|rank| SearchResult {
            title: format!("{query} - Rust Documentation"),
            description: format!("Official documentation for the {query} crate"),
            url: format!("{base_url}/{query}/latest/{query}/"),
            relevance_score: 0.9 - (rank as f32 * 0.1),
        })
        .collect();
    if let Ok(serialized) = serde_json::to_string(&results) {
        self.cache_result(&cache_key, &serialized).await;
    }
    self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
        .await;
    Ok(results)
}
/// Searches GitHub repositories for `query`, restricted to Rust projects.
///
/// Requires a `SearchEngine::GitHub` entry in the config; its optional
/// token is sent for authenticated (higher) rate limits.
///
/// # Errors
/// - `OpenCratesError::Config` if no GitHub engine is configured.
/// - `OpenCratesError::Search` if the request or response parsing fails.
async fn search_github(
    &self,
    query: &str,
    limit: usize,
) -> Result<Vec<SearchResult>, OpenCratesError> {
    let start_time = Instant::now();
    let engine_name = "github";
    self.enforce_rate_limit().await?;
    let github_config = self.config.engines.iter().find_map(|e| match e {
        SearchEngine::GitHub { token, api_url } => Some((token.as_ref(), api_url.as_ref())),
        _ => None,
    });
    let (token, api_url) = if let Some((t, u)) = github_config {
        (
            t,
            u.map_or("https://api.github.com", std::string::String::as_str),
        )
    } else {
        let error_msg = "GitHub search engine not configured".to_string();
        self.update_engine_stats(
            engine_name,
            false,
            start_time.elapsed(),
            Some(error_msg.clone()),
        )
        .await;
        return Err(OpenCratesError::Config(error_msg));
    };
    // Cache fast-path (no-op while the cache helpers are stubs).
    let cache_key = format!("github:{query}:{limit}");
    if let Some(cached) = self.get_cached_result(&cache_key).await {
        if let Ok(results) = serde_json::from_str::<Vec<SearchResult>>(&cached) {
            self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
                .await;
            return Ok(results);
        }
    }
    // Percent-encode the *whole* search expression, qualifier included.
    // Previously only `query` was encoded and the literal " language:rust"
    // left a raw space inside the URL's query string.
    let search_expr = format!("{query} language:rust");
    let mut request = self.client.get(format!(
        "{}/search/repositories?q={}&per_page={}",
        api_url,
        urlencoding::encode(&search_expr),
        limit
    ));
    if let Some(auth_token) = token {
        request = request.header("Authorization", format!("token {auth_token}"));
    }
    let response = match request.send().await {
        Ok(resp) => resp,
        Err(e) => {
            let error_msg = format!("Failed to search GitHub: {e}");
            self.update_engine_stats(
                engine_name,
                false,
                start_time.elapsed(),
                Some(error_msg.clone()),
            )
            .await;
            return Err(OpenCratesError::Search(error_msg));
        }
    };
    let github_response: GitHubSearchResponse = match response.json().await {
        Ok(data) => data,
        Err(e) => {
            let error_msg = format!("Failed to parse GitHub response: {e}");
            self.update_engine_stats(
                engine_name,
                false,
                start_time.elapsed(),
                Some(error_msg.clone()),
            )
            .await;
            return Err(OpenCratesError::Search(error_msg));
        }
    };
    // log10(stars) gives a gentle popularity-based relevance; the 0.1
    // floor also absorbs the -inf produced by zero-star repositories.
    let results: Vec<SearchResult> = github_response
        .items
        .into_iter()
        .take(limit)
        .map(|repo| SearchResult {
            title: repo.full_name,
            description: repo
                .description
                .unwrap_or_else(|| "GitHub repository".to_string()),
            url: repo.html_url,
            relevance_score: (repo.stargazers_count as f32).log10().max(0.1),
        })
        .collect();
    if let Ok(serialized) = serde_json::to_string(&results) {
        self.cache_result(&cache_key, &serialized).await;
    }
    self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
        .await;
    Ok(results)
}
/// Queries the DuckDuckGo Instant Answer API for `query`.
///
/// Builds results from the abstract (if present) plus related topics; if
/// nothing usable comes back, a single low-relevance fallback link to a
/// manual DuckDuckGo search is returned instead.
///
/// # Errors
/// Returns `OpenCratesError::Search` if the request or JSON parsing fails.
async fn search_duckduckgo(
&self,
query: &str,
limit: usize,
) -> Result<Vec<SearchResult>, OpenCratesError> {
let start_time = Instant::now();
let engine_name = "duckduckgo";
self.enforce_rate_limit().await?;
// Cache fast-path (no-op while the cache helpers are stubs).
let cache_key = format!("duckduckgo:{query}:{limit}");
if let Some(cached) = self.get_cached_result(&cache_key).await {
if let Ok(results) = serde_json::from_str::<Vec<SearchResult>>(&cached) {
self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
.await;
return Ok(results);
}
}
// Configured override or the public Instant Answer endpoint.
let base_url = self
.config
.engines
.iter()
.find_map(|e| match e {
SearchEngine::DuckDuckGo { base_url } => base_url.as_ref(),
_ => None,
})
.map_or("https://api.duckduckgo.com", std::string::String::as_str);
let url = format!(
"{}?q={}&format=json&pretty=1&no_html=1&skip_disambig=1",
base_url,
urlencoding::encode(query)
);
let response = match self.client.get(&url).send().await {
Ok(resp) => resp,
Err(e) => {
let error_msg = format!("Failed to search DuckDuckGo: {e}");
self.update_engine_stats(
engine_name,
false,
start_time.elapsed(),
Some(error_msg.clone()),
)
.await;
return Err(OpenCratesError::Search(error_msg));
}
};
let ddg_response: DuckDuckGoResponse = match response.json().await {
Ok(data) => data,
Err(e) => {
let error_msg = format!("Failed to parse DuckDuckGo response: {e}");
self.update_engine_stats(
engine_name,
false,
start_time.elapsed(),
Some(error_msg.clone()),
)
.await;
return Err(OpenCratesError::Search(error_msg));
}
};
let mut results: Vec<SearchResult> = Vec::new();
// The "Abstract" (when non-empty) is the best single answer.
if !ddg_response.abstract_text.is_empty() {
results.push(SearchResult {
title: ddg_response.heading.unwrap_or_else(|| query.to_string()),
description: ddg_response.abstract_text,
url: ddg_response.abstract_url.unwrap_or_else(|| {
format!("https://duckduckgo.com/?q={}", urlencoding::encode(query))
}),
relevance_score: 1.0,
});
}
// Fill the remaining slots from related topics that carry a URL.
for topic in ddg_response
.related_topics
.into_iter()
.take(limit.saturating_sub(results.len()))
{
if let Some(first_url) = topic.first_url {
results.push(SearchResult {
title: topic.text.clone(),
description: topic.text,
url: first_url,
relevance_score: 0.8,
});
}
}
// Fallback: point the user at a manual DuckDuckGo search.
if results.is_empty() {
results.push(SearchResult {
title: format!("Search results for '{query}'"),
description: "No specific results found, but you can search manually".to_string(),
url: format!(
"https://duckduckgo.com/?q={} rust crate",
urlencoding::encode(query)
),
relevance_score: 0.3,
});
}
results.truncate(limit);
if let Ok(serialized) = serde_json::to_string(&results) {
self.cache_result(&cache_key, &serialized).await;
}
self.update_engine_stats(engine_name, true, start_time.elapsed(), None)
.await;
Ok(results)
}
async fn aggregate_search_results(
&self,
query: &str,
options: &SearchOptions,
) -> Result<Vec<SearchResult>, OpenCratesError> {
let mut all_results = Vec::new();
let per_engine_limit = options.limit / self.config.engines.len().max(1);
for engine in &self.config.engines {
let engine_results = match engine {
SearchEngine::CratesIo { .. } => {
self.search_crates_io(query, per_engine_limit).await
}
SearchEngine::DocsRs { .. } => self.search_docs_rs(query, per_engine_limit).await,
SearchEngine::GitHub { .. } => self.search_github(query, per_engine_limit).await,
SearchEngine::DuckDuckGo { .. } => {
self.search_duckduckgo(query, per_engine_limit).await
}
_ => {
warn!("Unsupported search engine: {:?}", engine);
Ok(Vec::new())
}
};
if let Ok(results) = engine_results {
all_results.extend(results);
} else {
warn!("Search engine '{:?}' failed for query '{}'", engine, query);
}
}
if options.filter_duplicates {
all_results.dedup_by(|a, b| a.url == b.url);
}
all_results.sort_by(|a, b| {
b.relevance_score
.partial_cmp(&a.relevance_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_results.truncate(options.limit);
Ok(all_results)
}
/// Heuristically pulls the main textual content out of an HTML page.
///
/// Tries a list of common content containers first, then a plain `<body>`
/// element, and finally falls back to stripping tags from the whole input.
fn extract_main_content(&self, html: &str) -> String {
    let content_selectors = [
        "<main>",
        "<article>",
        "<div class=\"content\">",
        "<div id=\"content\">",
        "<section>",
        "<body>",
    ];
    for selector in &content_selectors {
        if let Some(start) = html.find(selector) {
            // Derive the closing tag from the element *name* only.
            // Replacing '<' with "</" in the full selector produced
            // strings like `</div class="content">`, which never occur in
            // HTML, so attribute-bearing selectors could never match.
            let tag_name: String = selector
                .trim_start_matches('<')
                .chars()
                .take_while(|c| c.is_ascii_alphanumeric())
                .collect();
            let end_tag = format!("</{tag_name}>");
            let content_start = start + selector.len();
            if let Some(end) = html[content_start..].find(&end_tag) {
                let content = &html[content_start..content_start + end];
                let cleaned = self.strip_html_tags(content);
                // Only accept a container holding a meaningful amount of
                // text; otherwise keep trying the next selector.
                if !cleaned.trim().is_empty() && cleaned.len() > 50 {
                    return cleaned;
                }
            }
        }
    }
    // Fallback: a plain <body> element (without attributes), if present.
    // The end > start check guards against malformed input where
    // "</body>" appears before "<body>" (slicing would panic).
    if let Some(body_start) = html.find("<body>") {
        if let Some(body_end) = html.find("</body>") {
            if body_end > body_start {
                let body_content = &html[body_start + 6..body_end];
                return self.strip_html_tags(body_content);
            }
        }
    }
    // Last resort: strip tags from the whole document.
    self.strip_html_tags(html)
}
/// Removes HTML tags, decodes a handful of common character entities, and
/// collapses all whitespace runs into single spaces.
fn strip_html_tags(&self, html: &str) -> String {
    let mut result = String::new();
    let mut in_tag = false;
    let mut chars = html.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '<' => in_tag = true,
            '>' => in_tag = false,
            '&' if !in_tag => {
                // Collect the entity name up to (but not including) ';'.
                let mut entity = String::new();
                let mut terminated = false;
                while let Some(&next_ch) = chars.peek() {
                    if next_ch == ';' {
                        chars.next();
                        terminated = true;
                        break;
                    } else if next_ch.is_alphanumeric() || next_ch == '#' {
                        entity.push(chars.next().expect("peeked char must exist"));
                    } else {
                        break;
                    }
                }
                // Decode the entity *instead of* echoing the raw '&'.
                // (Previously '&' was pushed before decoding, so "&amp;"
                // came out as "&&" and unknown entities were doubled.)
                match entity.as_str() {
                    "amp" if terminated => result.push('&'),
                    "lt" if terminated => result.push('<'),
                    "gt" if terminated => result.push('>'),
                    "quot" if terminated => result.push('"'),
                    "apos" if terminated => result.push('\''),
                    "nbsp" if terminated => result.push(' '),
                    _ => {
                        // Unknown or unterminated entity: reproduce it
                        // verbatim, without inventing a closing ';'.
                        result.push('&');
                        result.push_str(&entity);
                        if terminated {
                            result.push(';');
                        }
                    }
                }
            }
            _ if !in_tag => result.push(ch),
            _ => {}
        }
    }
    // Collapse consecutive whitespace into single spaces.
    result.split_whitespace().collect::<Vec<_>>().join(" ")
}
}
#[async_trait]
impl SearchProvider for WebSearchProvider {
/// Runs `query` across all configured engines and returns the merged,
/// relevance-ranked results (see `aggregate_search_results`).
async fn search(
&self,
query: &str,
options: &SearchOptions,
) -> Result<Vec<SearchResult>, OpenCratesError> {
info!(
"Enhanced search for: {} with {} engines",
query,
self.config.engines.len()
);
self.aggregate_search_results(query, options).await
}
/// Searches for code examples matching `query`, appending a language
/// qualifier (defaulting to Rust) and enabling code/snippet options.
async fn search_code(
    &self,
    query: &str,
    language: Option<&str>,
) -> Result<Vec<SearchResult>, OpenCratesError> {
    // Qualify the query: "language:<lang>" when one is given, plain
    // "rust" otherwise.
    let search_query = language.map_or_else(
        || format!("{query} rust"),
        |lang| format!("{query} language:{lang}"),
    );
    let options = SearchOptions {
        limit: 10,
        max_results: 10,
        timeout: Some(Duration::from_secs(30)),
        include_code: true,
        include_snippets: true,
        language_filter: language.map(std::string::ToString::to_string),
        filter_duplicates: true,
    };
    self.search(&search_query, &options).await
}
/// Fetches the docs.rs landing page for `crate_name` and extracts its
/// main textual content.
///
/// # Errors
/// Returns `OpenCratesError::Search` if the page cannot be fetched or its
/// body cannot be read.
async fn get_documentation(&self, crate_name: &str) -> Result<String, OpenCratesError> {
    let start_time = Instant::now();
    info!("Fetching documentation for crate: {}", crate_name);
    self.enforce_rate_limit().await?;
    // Cache fast-path (no-op while the cache helpers are stubs).
    let cache_key = format!("docs:{crate_name}");
    if let Some(cached) = self.get_cached_result(&cache_key).await {
        return Ok(cached);
    }
    let base_url = self
        .config
        .engines
        .iter()
        .find_map(|e| match e {
            SearchEngine::DocsRs { base_url } => base_url.as_ref(),
            _ => None,
        })
        .map_or("https://docs.rs", std::string::String::as_str);
    // docs.rs keeps the published (possibly hyphenated) crate name in the
    // first path segment, but the generated rustdoc lives under the
    // crate's *module* path, where hyphens become underscores
    // (e.g. /async-trait/latest/async_trait/). Using the raw name for
    // both segments 404s for every hyphenated crate.
    let module_path = crate_name.replace('-', "_");
    let url = format!("{base_url}/{crate_name}/latest/{module_path}/");
    let response_text = match self.client.get(&url).send().await {
        Ok(response) => match response.text().await {
            Ok(text) => text,
            Err(e) => {
                let error_msg = format!("Failed to read documentation response: {e}");
                self.update_engine_stats(
                    "docs.rs",
                    false,
                    start_time.elapsed(),
                    Some(error_msg.clone()),
                )
                .await;
                return Err(OpenCratesError::Search(error_msg));
            }
        },
        Err(e) => {
            let error_msg = format!("Failed to fetch documentation: {e}");
            self.update_engine_stats(
                "docs.rs",
                false,
                start_time.elapsed(),
                Some(error_msg.clone()),
            )
            .await;
            return Err(OpenCratesError::Search(error_msg));
        }
    };
    let extracted_content = self.extract_main_content(&response_text);
    self.cache_result(&cache_key, &extracted_content).await;
    self.update_engine_stats("docs.rs", true, start_time.elapsed(), None)
        .await;
    Ok(extracted_content)
}
/// Verifies the provider end to end by running one tiny real query.
///
/// Reports `Ok(false)` on failure rather than propagating the error, so
/// callers can treat any `Err` as unexpected.
async fn health_check(&self) -> Result<bool, OpenCratesError> {
    let options = SearchOptions {
        limit: 1,
        max_results: 1,
        timeout: Some(Duration::from_secs(10)),
        include_code: false,
        include_snippets: false,
        language_filter: None,
        filter_duplicates: false,
    };
    match self.search("rust programming", &options).await {
        Ok(results) => Ok(!results.is_empty()),
        Err(e) => {
            error!("Health check failed: {}", e);
            // Unhealthy, not an error: the check itself succeeded.
            Ok(false)
        }
    }
}
/// Provider identifier used in logs and diagnostics.
fn name(&self) -> &'static str {
"EnhancedWebSearch"
}
}
/// Top-level shape of a crates.io `/api/v1/crates` search response.
#[derive(Deserialize)]
struct CratesIoResponse {
crates: Vec<CrateInfo>,
}
/// Per-crate fields consumed from the crates.io search response.
#[derive(Deserialize)]
struct CrateInfo {
// Crate id as used in URLs.
id: String,
name: String,
description: Option<String>,
// Download count; used as a crude relevance signal (downloads / 1000).
downloads: Option<u64>,
}
/// Top-level shape of a GitHub `/search/repositories` response.
#[derive(Deserialize)]
struct GitHubSearchResponse {
items: Vec<GitHubRepo>,
}
/// Per-repository fields consumed from the GitHub search response.
#[derive(Deserialize)]
struct GitHubRepo {
// "owner/name" slug.
full_name: String,
description: Option<String>,
html_url: String,
// Star count; log10 of this becomes the relevance score.
stargazers_count: u32,
}
/// Subset of the DuckDuckGo Instant Answer API response.
///
/// Every field carries `#[serde(default)]` so a response that omits any
/// of them still deserializes; without it the non-`Option` fields
/// (`Abstract`, `RelatedTopics`) made the whole parse fail when the API
/// returned a sparse answer.
#[derive(Deserialize)]
struct DuckDuckGoResponse {
    #[serde(rename = "Abstract", default)]
    abstract_text: String,
    #[serde(rename = "AbstractURL", default)]
    abstract_url: Option<String>,
    #[serde(rename = "Heading", default)]
    heading: Option<String>,
    #[serde(rename = "RelatedTopics", default)]
    related_topics: Vec<DuckDuckGoTopic>,
}
/// One entry of the DuckDuckGo `RelatedTopics` array.
///
/// `#[serde(default)]` tolerates topic-group entries that omit `Text`
/// (which previously failed the entire response parse); entries without a
/// `FirstURL` are skipped by the caller.
#[derive(Deserialize)]
struct DuckDuckGoTopic {
    #[serde(rename = "Text", default)]
    text: String,
    #[serde(rename = "FirstURL", default)]
    first_url: Option<String>,
}
impl WebSearchProvider {
/// Starts a [`WebSearchProviderBuilder`] for fluent configuration.
#[must_use]
pub fn builder() -> WebSearchProviderBuilder {
WebSearchProviderBuilder::new()
}
}
/// Fluent builder for [`WebSearchProvider`].
#[derive(Debug)]
pub struct WebSearchProviderBuilder {
// Configuration accumulated by the builder methods.
config: WebSearchConfig,
}
impl WebSearchProviderBuilder {
    /// Starts from the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: WebSearchConfig::default(),
        }
    }

    /// Replaces the entire configuration in one call.
    #[must_use]
    pub fn with_config(mut self, config: WebSearchConfig) -> Self {
        self.config = config;
        self
    }

    /// Appends an additional search engine to the list.
    #[must_use]
    pub fn add_engine(mut self, engine: SearchEngine) -> Self {
        self.config.engines.push(engine);
        self
    }

    /// Sets the maximum number of outgoing requests per second.
    #[must_use]
    pub fn with_rate_limit(mut self, requests_per_second: u32) -> Self {
        self.config.rate_limit_per_second = requests_per_second;
        self
    }

    /// Sets the per-request HTTP timeout, in seconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
        self.config.timeout_seconds = timeout_seconds;
        self
    }

    /// Turns result caching on with the given TTL, in seconds.
    #[must_use]
    pub fn enable_caching(mut self, ttl_seconds: u64) -> Self {
        self.config.enable_caching = true;
        self.config.cache_ttl_seconds = ttl_seconds;
        self
    }

    /// Turns result caching off.
    #[must_use]
    pub fn disable_caching(mut self) -> Self {
        self.config.enable_caching = false;
        self
    }

    /// Consumes the builder and constructs the provider.
    #[must_use]
    pub fn build(self) -> WebSearchProvider {
        WebSearchProvider::with_config(self.config)
    }
}
impl Default for WebSearchProviderBuilder {
/// Equivalent to [`WebSearchProviderBuilder::new`].
fn default() -> Self {
Self::new()
}
}