#![cfg_attr(docsrs, feature(doc_cfg))]
pub mod config;
pub mod scraper;
pub mod utils;
#[cfg(feature = "backends")]
#[cfg_attr(docsrs, doc(cfg(feature = "backends")))]
pub mod backends;
#[cfg(feature = "llm")]
#[cfg_attr(docsrs, doc(cfg(feature = "llm")))]
pub mod llm;
pub use config::Config;
#[cfg(feature = "text-map")]
#[cfg_attr(docsrs, doc(cfg(feature = "text-map")))]
#[derive(Debug, Clone, serde::Serialize)]
/// One text node extracted from an HTML document.
pub struct TextNode {
    /// Identifier used to address this node (e.g. from a [`TextReplacement`]).
    pub id: usize,
    /// The node's text content.
    pub text: String,
}
#[cfg(feature = "text-map")]
#[cfg_attr(docsrs, doc(cfg(feature = "text-map")))]
#[derive(Debug, Clone, serde::Serialize)]
/// All text nodes extracted from one HTML document, plus its title.
///
/// Produced by [`extract_text_nodes`].
pub struct TextMap {
    /// The extracted text nodes.
    pub nodes: Vec<TextNode>,
    /// Document title as reported by the extractor.
    pub title: String,
}
#[cfg(feature = "text-map")]
#[cfg_attr(docsrs, doc(cfg(feature = "text-map")))]
#[derive(Debug, Clone)]
/// A replacement for a single text node, consumed by [`replace_text_nodes`].
pub struct TextReplacement {
    /// Id of the [`TextNode`] whose text should be replaced.
    pub id: usize,
    /// New text for that node.
    pub text: String,
}
#[derive(Debug, Clone, serde::Serialize)]
/// Result of cleaning raw HTML via [`clean`].
pub struct CleanResult {
    /// The cleaned, boilerplate-stripped text.
    pub text: String,
    /// Extracted document title.
    pub title: String,
    /// Whether the text was cut off at the `max_chars` limit.
    pub truncated: bool,
    /// Length of `text` in bytes (`String::len`), not Unicode scalars.
    pub char_count: usize,
}
#[derive(Debug, Clone, serde::Serialize)]
/// Result of fetching and cleaning a single URL via [`fetch`].
pub struct FetchResult {
    /// The URL that was fetched.
    pub url: String,
    /// Extracted page title.
    pub title: String,
    /// Cleaned page text.
    pub text: String,
    /// Whether the text was cut off at the configured length limit.
    pub truncated: bool,
    /// Length of `text` in bytes (`String::len`), not Unicode scalars.
    pub char_count: usize,
}
#[derive(Debug, Clone, serde::Serialize)]
/// One retrieved source in a [`QueryResult`].
pub struct Source {
    /// 1-based position after reranking.
    pub id: usize,
    /// Page title (falls back to the search-result title when the page has none).
    pub title: String,
    /// Source URL.
    pub url: String,
    /// Search-engine snippet, present only when it differs from `content`.
    pub snippet: Option<String>,
    /// Cleaned page text, or the snippet / "[Fetch failed]" when the fetch failed.
    pub content: String,
    /// Whether `content` was cut down by the budget allocation.
    pub truncated: bool,
}
#[derive(Debug, Clone, serde::Serialize)]
/// A reserve search hit that was not fetched; only its snippet is kept.
pub struct SnippetEntry {
    /// 1-based id continuing after the last [`Source`] id.
    pub id: usize,
    /// Search-result title.
    pub title: String,
    /// Search-result URL.
    pub url: String,
    /// Search-engine snippet text.
    pub snippet: String,
}
#[derive(Debug, Clone, serde::Serialize)]
/// Aggregate statistics about one pipeline run.
pub struct Stats {
    /// Number of pages fetched and cleaned successfully.
    pub fetched: usize,
    /// Number of candidates whose fetch failed (snippet-only fallback).
    pub failed: usize,
    /// Number of failed candidates replaced from the reserve pool.
    pub gap_filled: usize,
    /// Total bytes of cleaned content across all sources.
    pub total_chars: usize,
    /// Per-source content limit used when adaptive budgeting is off.
    pub per_page_limit: usize,
    /// Effective results-per-query value after clamping.
    pub num_results_per_query: usize,
    /// Total raw bytes downloaded across all fetches.
    pub raw_bytes: usize,
}
#[derive(Debug, Clone, serde::Serialize)]
/// Full output of the search pipeline ([`query`] / [`query_with_options`]).
pub struct QueryResult {
    /// The queries actually searched (after optional LLM expansion).
    pub queries: Vec<String>,
    /// Fetched, cleaned and reranked sources.
    pub sources: Vec<Source>,
    /// Unfetched reserve hits, exposed as snippets.
    pub snippet_pool: Vec<SnippetEntry>,
    /// Run statistics.
    pub stats: Stats,
    /// LLM-generated summary, when summarization is enabled and succeeded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<String>,
    /// Error text when summarization was attempted but failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_summary_error: Option<String>,
}
#[derive(Debug, thiserror::Error)]
/// Unified error type for all fallible operations in this crate.
pub enum WebshiftError {
    /// Transport-level failure from the HTTP client.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Content could not be parsed, or a URL was filtered/failed to fetch.
    #[error("Parse error: {0}")]
    Parse(String),
    /// Invalid or missing configuration.
    #[error("Configuration error: {0}")]
    Config(String),
    /// Search-backend failure (unknown backend, no queries, ...).
    #[error("Backend error: {0}")]
    Backend(String),
    /// LLM-stage failure (expansion, rerank, summarization).
    #[error("LLM error: {0}")]
    Llm(String),
}
/// Clean raw HTML into readable text, truncating the result to `max_chars`.
///
/// Boilerplate removal and title extraction are delegated to
/// [`scraper::cleaner::process_page`]. `char_count` is the byte length of the
/// cleaned text (`String::len`), consistent with the rest of this crate.
pub fn clean(raw_html: &str, max_chars: usize) -> CleanResult {
    let (text, title, truncated) = scraper::cleaner::process_page(raw_html, "", max_chars);
    CleanResult {
        // Evaluated before `text` is moved into the struct below.
        char_count: text.len(),
        text,
        title,
        truncated,
    }
}
#[cfg(feature = "text-map")]
#[cfg_attr(docsrs, doc(cfg(feature = "text-map")))]
/// Extract the individual text nodes (and the title) from raw HTML.
///
/// Thin wrapper over [`scraper::cleaner::extract_text_nodes`]; the returned
/// [`TextMap`] can be edited and written back via [`replace_text_nodes`].
pub fn extract_text_nodes(raw_html: &str) -> TextMap {
    let (nodes, title) = scraper::cleaner::extract_text_nodes(raw_html);
    TextMap { nodes, title }
}
#[cfg(feature = "text-map")]
#[cfg_attr(docsrs, doc(cfg(feature = "text-map")))]
/// Apply text-node replacements to raw HTML and return the rewritten document.
///
/// Thin wrapper over [`scraper::textmap::replace_text_nodes`]; node ids are
/// the ones assigned by [`extract_text_nodes`].
///
/// # Errors
///
/// Propagates any [`WebshiftError`] reported by the underlying implementation.
pub fn replace_text_nodes(
    raw_html: &str,
    replacements: &[TextReplacement],
) -> Result<String, WebshiftError> {
    scraper::textmap::replace_text_nodes(raw_html, replacements)
}
/// Fetch a single URL and return its cleaned text content.
///
/// The URL is rejected up-front when it points at a known binary file type or
/// is excluded by the configured allow/block domain lists. Download size,
/// timeout and output length come from `config.server`.
///
/// # Errors
///
/// Returns [`WebshiftError::Parse`] when the URL is filtered out or the page
/// could not be fetched.
pub async fn fetch(url: &str, config: &Config) -> Result<FetchResult, WebshiftError> {
    if utils::url::is_binary_url(url) {
        return Err(WebshiftError::Parse(format!(
            "binary file URL filtered: {}",
            url
        )));
    }
    if !utils::url::is_domain_allowed(
        url,
        &config.server.blocked_domains,
        &config.server.allowed_domains,
    ) {
        return Err(WebshiftError::Parse(format!(
            "URL blocked by domain filter: {}",
            url
        )));
    }
    let max_bytes = config.server.max_download_bytes();
    let timeout = config.server.search_timeout;
    let (mut html_map, _timing) =
        scraper::fetcher::fetch_urls(&[url.to_string()], max_bytes, timeout).await;
    // Take ownership of the body instead of cloning it: fetched pages can be
    // large and the map is discarded right after this point anyway.
    let raw = html_map
        .remove(url)
        .ok_or_else(|| WebshiftError::Parse(format!("fetch failed: {}", url)))?;
    let max_chars = config.server.max_result_length;
    let (text, title, truncated) = scraper::cleaner::process_page(&raw, "", max_chars);
    // Byte length, not Unicode scalar count — consistent with `clean`.
    let char_count = text.len();
    Ok(FetchResult {
        url: url.to_string(),
        title,
        text,
        truncated,
        char_count,
    })
}
#[cfg(feature = "backends")]
#[cfg_attr(docsrs, doc(cfg(feature = "backends")))]
/// Run the full search pipeline with default options.
///
/// Convenience wrapper around [`query_with_options`] using the configured
/// defaults for result count, language and backend selection.
///
/// # Errors
///
/// See [`query_with_options`].
pub async fn query(queries: &[&str], config: &Config) -> Result<QueryResult, WebshiftError> {
    query_with_options(queries, config, None, None, None).await
}
#[cfg(feature = "backends")]
#[cfg_attr(docsrs, doc(cfg(feature = "backends")))]
/// Run the full search pipeline with explicit overrides.
///
/// Stages: (optional LLM query expansion) -> concurrent backend search ->
/// round-robin merge -> dedup/filter -> fetch -> gap-fill refetch -> clean ->
/// rerank -> per-source budget truncation -> (optional LLM rerank and
/// summarization).
///
/// # Parameters
///
/// * `queries` - seed queries; capped at `server.max_search_queries`.
/// * `num_results_per_query` - overrides `server.results_per_query`.
/// * `lang` - overrides `server.language` (an empty configured value means unset).
/// * `backend_name` - selects a named backend instead of the configured default.
///
/// # Errors
///
/// Returns [`WebshiftError::Backend`] when the backend cannot be created or no
/// queries remain after clamping. Search errors for *individual* queries are
/// logged and skipped rather than propagated.
pub async fn query_with_options(
    queries: &[&str],
    config: &Config,
    num_results_per_query: Option<usize>,
    lang: Option<&str>,
    backend_name: Option<&str>,
) -> Result<QueryResult, WebshiftError> {
    use backends::{create_backend, create_backend_by_name, SearchResult as BackendResult};
    let cfg = &config.server;
    let backend = match backend_name {
        Some(name) => create_backend_by_name(name, &config.backends)?,
        None => create_backend(&config.backends)?,
    };
    // Cap the caller-supplied queries at the configured maximum.
    let base_queries: Vec<String> = queries
        .iter()
        .take(cfg.max_search_queries)
        .map(|s| s.to_string())
        .collect();
    // With the `llm` feature, a single seed query may be expanded by the LLM
    // into several related queries before searching.
    #[cfg(feature = "llm")]
    let queries_list: Vec<String> = if config.llm.enabled
        && config.llm.expansion_enabled
        && base_queries.len() == 1
    {
        let llm_client = llm::client::LlmClient::new(&config.llm);
        let expanded = llm::expander::expand_queries(
            &base_queries[0],
            cfg.max_search_queries,
            &llm_client,
        )
        .await;
        expanded.into_iter().take(cfg.max_search_queries).collect()
    } else {
        base_queries
    };
    #[cfg(not(feature = "llm"))]
    let queries_list: Vec<String> = base_queries;
    if queries_list.is_empty() {
        return Err(WebshiftError::Backend("no queries provided".into()));
    }
    // Per-query and overall result counts, both bounded by max_total_results.
    let nrpq = num_results_per_query
        .unwrap_or(cfg.results_per_query)
        .min(cfg.max_total_results);
    let total_results = (nrpq * queries_list.len()).min(cfg.max_total_results);
    // Ask the backend for more hits than needed so deduplication, filtering
    // and fetch failures can be absorbed by the surplus.
    let oversample_count = nrpq * cfg.oversampling_factor as usize;
    let resolved_lang: Option<&str> = lang.or_else(|| {
        let l = cfg.language.as_str();
        if l.is_empty() { None } else { Some(l) }
    });
    // Run all searches concurrently.
    let search_futures: Vec<_> = queries_list
        .iter()
        .map(|q| backend.search(q, oversample_count, resolved_lang))
        .collect();
    let results_per_query = futures::future::join_all(search_futures).await;
    let mut result_lists: Vec<Vec<BackendResult>> = Vec::new();
    for r in results_per_query {
        match r {
            Ok(list) => result_lists.push(list),
            Err(e) => {
                // A failed query is tolerated; continue with what the other
                // queries produced.
                tracing::warn!("backend search error: {e}");
            }
        }
    }
    // Round-robin merge: interleave the per-query lists so every query's top
    // hits come before anyone's lower-ranked hits.
    let max_len = result_lists.iter().map(|l| l.len()).max().unwrap_or(0);
    let mut raw_results: Vec<BackendResult> = Vec::new();
    for i in 0..max_len {
        for list in &result_lists {
            if i < list.len() {
                raw_results.push(list[i].clone());
            }
        }
    }
    // Deduplicate by normalized URL (sanitized, lowercased, trailing '/'
    // stripped) and drop binary or domain-filtered URLs.
    let mut valid: Vec<BackendResult> = Vec::new();
    let mut seen_urls: std::collections::HashSet<String> = std::collections::HashSet::new();
    for r in &raw_results {
        let clean = utils::url::sanitize_url(&r.url).to_lowercase();
        let clean = clean.trim_end_matches('/').to_string();
        if seen_urls.contains(&clean) || utils::url::is_binary_url(&r.url) {
            continue;
        }
        if !utils::url::is_domain_allowed(&r.url, &cfg.blocked_domains, &cfg.allowed_domains) {
            continue;
        }
        seen_urls.insert(clean);
        valid.push(r.clone());
    }
    // First `total_results` hits are fetched; the rest form a reserve used for
    // gap-filling and, ultimately, the snippet pool.
    let candidates: Vec<BackendResult> = valid.iter().take(total_results).cloned().collect();
    let mut reserve_pool: Vec<BackendResult> = valid.iter().skip(total_results).cloned().collect();
    let candidate_urls: Vec<String> = candidates.iter().map(|r| r.url.clone()).collect();
    let max_bytes = cfg.max_download_bytes();
    let (mut html_map, mut fetch_timing) =
        scraper::fetcher::fetch_urls(&candidate_urls, max_bytes, cfg.search_timeout).await;
    let mut gap_filled: usize = 0;
    let mut final_candidates = candidates.clone();
    // Gap-fill: when fetches failed, fetch one reserve hit per failure and
    // swap the failures out of the candidate set.
    if cfg.auto_recovery_fetch && !reserve_pool.is_empty() {
        let failed: Vec<&BackendResult> = candidates
            .iter()
            .filter(|r| !html_map.contains_key(&r.url))
            .collect();
        if !failed.is_empty() {
            let gap_size = failed.len().min(reserve_pool.len());
            let backups: Vec<BackendResult> = reserve_pool.drain(..gap_size).collect();
            let backup_urls: Vec<String> = backups.iter().map(|r| r.url.clone()).collect();
            let (backup_html, backup_timing) =
                scraper::fetcher::fetch_urls(&backup_urls, max_bytes, cfg.search_timeout).await;
            html_map.extend(backup_html);
            fetch_timing.extend(backup_timing);
            // Keep the candidates that fetched, then append the backups
            // (whether or not *their* fetch succeeded — a failed backup
            // falls back to its snippet further down).
            let mut new_candidates: Vec<BackendResult> = final_candidates
                .iter()
                .filter(|r| html_map.contains_key(&r.url))
                .cloned()
                .collect();
            new_candidates.extend(backups);
            // Failed originals are pushed to the front of the reserve pool so
            // their snippets still appear in the snippet pool.
            let still_failed: Vec<BackendResult> = final_candidates
                .iter()
                .filter(|r| !html_map.contains_key(&r.url))
                .cloned()
                .collect();
            reserve_pool = still_failed.into_iter().chain(reserve_pool).collect();
            gap_filled = gap_size;
            final_candidates = new_candidates;
        }
    }
    // Evenly split the total query budget across sources, capped per page.
    let per_page_limit = cfg
        .max_result_length
        .min(cfg.max_query_budget / final_candidates.len().max(1));
    // Under adaptive budgeting, keep more text at the cleaning stage so the
    // later score-based truncation has material to allocate from.
    let fetch_limit = if cfg.adaptive_budget != config::AdaptiveBudget::Off {
        cfg.max_result_length * cfg.adaptive_budget_fetch_factor as usize
    } else {
        per_page_limit
    };
    let mut sources: Vec<Source> = Vec::new();
    let mut fetched_count: usize = 0;
    let mut failed_count: usize = 0;
    for (idx, result) in final_candidates.iter().enumerate() {
        let raw = html_map.get(&result.url);
        if let Some(raw) = raw {
            let (text, title, truncated) =
                scraper::cleaner::process_page(raw, &result.snippet, fetch_limit);
            fetched_count += 1;
            // Only keep the snippet when it adds information beyond the body.
            let snippet = if !result.snippet.is_empty() && result.snippet != text {
                Some(result.snippet.clone())
            } else {
                None
            };
            sources.push(Source {
                id: idx + 1,
                title: if title.is_empty() {
                    result.title.clone()
                } else {
                    title
                },
                url: result.url.clone(),
                snippet,
                content: text,
                truncated,
            });
        } else {
            // Fetch failed: fall back to the search snippet (or a marker).
            failed_count += 1;
            let text = if result.snippet.is_empty() {
                "[Fetch failed]".to_string()
            } else {
                result.snippet.clone()
            };
            sources.push(Source {
                id: idx + 1,
                title: result.title.clone(),
                url: result.url.clone(),
                snippet: None,
                content: text,
                truncated: false,
            });
        }
    }
    // Rerank; BM25 scores are only computed when adaptive budgeting may use them.
    let (bm25_scores_opt, reranked) = match cfg.adaptive_budget {
        config::AdaptiveBudget::Off => {
            (None, utils::reranker::rerank_deterministic(&queries_list, &sources))
        }
        _ => {
            let (scores, reranked) =
                utils::reranker::rerank_with_scores(&queries_list, &sources);
            (Some(scores), reranked)
        }
    };
    sources = reranked;
    // `Auto` only switches adaptive budgeting on when relevance is clearly
    // skewed: best score > 1.5x the mean score.
    let use_adaptive = match cfg.adaptive_budget {
        config::AdaptiveBudget::On => true,
        config::AdaptiveBudget::Off => false,
        config::AdaptiveBudget::Auto => bm25_scores_opt.as_ref().is_some_and(|scores| {
            let total: f64 = scores.iter().sum();
            let max: f64 = scores.iter().cloned().fold(0.0_f64, f64::max);
            let n = scores.len() as f64;
            total > 0.0 && (max / total * n) > 1.5
        }),
    };
    if use_adaptive {
        // Allocate budget proportionally to BM25 score, clamped to
        // [200, fetch_limit]; equal split when all scores are zero.
        let bm25_scores = bm25_scores_opt.unwrap();
        let total_budget = cfg.max_query_budget;
        let total_score: f64 = bm25_scores.iter().sum();
        let mut allocs: Vec<usize> = if total_score > 0.0 {
            bm25_scores
                .iter()
                .map(|&s| {
                    (s / total_score * total_budget as f64)
                        .round()
                        .max(200.0)
                        .min(fetch_limit as f64) as usize
                })
                .collect()
        } else {
            vec![total_budget / sources.len().max(1); sources.len()]
        };
        allocs = utils::reranker::redistribute_budget(&sources, &allocs, &bm25_scores);
        for (source, &alloc) in sources.iter_mut().zip(allocs.iter()) {
            // NOTE(review): compares byte length but truncates by *chars*; for
            // multi-byte content the two can disagree (truncated may be set
            // without shortening, or content may exceed alloc bytes) — confirm
            // whether budgets are meant as bytes or characters.
            if source.content.len() > alloc {
                source.content = source.content.chars().take(alloc).collect();
                source.truncated = true;
            }
        }
    } else {
        for source in &mut sources {
            // NOTE(review): same byte-vs-char mismatch as above.
            if source.content.len() > per_page_limit {
                source.content = source.content.chars().take(per_page_limit).collect();
                source.truncated = true;
            }
        }
    }
    // Optional LLM-based reranking pass on top of the deterministic order.
    #[cfg(feature = "llm")]
    if config.llm.enabled && config.llm.llm_rerank_enabled {
        let llm_client = llm::client::LlmClient::new(&config.llm);
        sources = utils::reranker::rerank_llm(&queries_list, &sources, &llm_client).await;
    }
    // Re-number ids to match the final order.
    for (i, source) in sources.iter_mut().enumerate() {
        source.id = i + 1;
    }
    // Remaining reserves are exposed snippet-only, ids continuing after sources.
    let snippet_pool: Vec<SnippetEntry> = reserve_pool
        .iter()
        .enumerate()
        .map(|(i, r)| SnippetEntry {
            id: sources.len() + i + 1,
            title: r.title.clone(),
            url: r.url.clone(),
            snippet: r.snippet.clone(),
        })
        .collect();
    let total_chars: usize = sources.iter().map(|s| s.content.len()).sum();
    let raw_bytes: usize = fetch_timing.values().map(|(_, b)| b).sum();
    // Optional LLM summarization; failures are captured, not propagated.
    #[cfg(feature = "llm")]
    let (summary, llm_summary_error) = if config.llm.enabled && config.llm.summarization_enabled {
        let llm_client = llm::client::LlmClient::new(&config.llm);
        let max_words = if config.llm.max_summary_words > 0 {
            config.llm.max_summary_words
        } else {
            cfg.max_query_budget / 5
        };
        match llm::summarizer::summarize_results(&queries_list, &sources, &llm_client, max_words)
            .await
        {
            Ok(s) => (Some(s), None),
            Err(e) => (None, Some(e.to_string())),
        }
    } else {
        (None, None)
    };
    #[cfg(not(feature = "llm"))]
    let (summary, llm_summary_error) = (None::<String>, None::<String>);
    Ok(QueryResult {
        queries: queries_list,
        sources,
        snippet_pool,
        stats: Stats {
            fetched: fetched_count,
            failed: failed_count,
            gap_filled,
            total_chars,
            per_page_limit,
            num_results_per_query: nrpq,
            raw_bytes,
        },
        summary,
        llm_summary_error,
    })
}
#[cfg(test)]
#[cfg(all(feature = "backends", feature = "llm"))]
mod llm_pipeline_tests {
    //! End-to-end pipeline tests for the LLM stages (query expansion and
    //! summarization), driven by wiremock-backed search, page and LLM servers.
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a `Config` pointing the SearXNG backend and LLM client at the
    /// given mock-server URLs, with small limits and timeouts for tests.
    fn mock_config_with_llm(searxng_url: &str, llm_base_url: &str) -> Config {
        let mut config = Config::default();
        config.backends.searxng.url = searxng_url.to_string();
        config.server.max_result_length = 4000;
        config.server.max_query_budget = 16000;
        config.server.max_total_results = 5;
        config.server.search_timeout = 5;
        config.llm.enabled = true;
        config.llm.base_url = llm_base_url.to_string();
        config.llm.model = "test-model".to_string();
        config.llm.timeout = 5;
        config
    }

    /// A single seed query with expansion enabled should still report the
    /// original query first and produce sources.
    #[tokio::test]
    async fn pipeline_with_query_expansion() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let llm_server = MockServer::start().await;
        // LLM responds with a JSON array of expanded queries.
        let llm_body = serde_json::json!({
            "choices": [{"message": {"content": "[\"rust async patterns\", \"tokio runtime tutorial\"]"}}]
        });
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&llm_body))
            .mount(&llm_server)
            .await;
        let page_url = format!("{}/page1", page_server.uri());
        let search_body = serde_json::json!({
            "results": [{"title": "Rust", "url": &page_url, "content": "Rust async"}]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page1"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><body><p>Rust async programming content.</p></body></html>",
            ))
            .mount(&page_server)
            .await;
        let mut config =
            mock_config_with_llm(&search_server.uri(), &format!("{}/v1", llm_server.uri()));
        config.llm.expansion_enabled = true;
        config.llm.summarization_enabled = false;
        config.llm.llm_rerank_enabled = false;
        let result = query(&["rust"], &config).await.unwrap();
        assert!(result.queries.len() >= 1);
        assert_eq!(result.queries[0], "rust");
        assert!(result.sources.len() >= 1);
    }

    /// With summarization enabled, a successful LLM call should populate
    /// `summary` and leave `llm_summary_error` empty.
    #[tokio::test]
    async fn pipeline_with_summarization() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let llm_server = MockServer::start().await;
        let llm_body = serde_json::json!({
            "choices": [{"message": {"content": "## Summary\n\nRust is a systems language [1]."}}]
        });
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&llm_body))
            .mount(&llm_server)
            .await;
        let page_url = format!("{}/page1", page_server.uri());
        let search_body = serde_json::json!({
            "results": [{"title": "Rust", "url": &page_url, "content": "Rust systems programming"}]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page1"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><body><p>Rust is a systems language.</p></body></html>",
            ))
            .mount(&page_server)
            .await;
        let mut config =
            mock_config_with_llm(&search_server.uri(), &format!("{}/v1", llm_server.uri()));
        config.llm.expansion_enabled = false;
        config.llm.summarization_enabled = true;
        config.llm.llm_rerank_enabled = false;
        let result = query(&["rust"], &config).await.unwrap();
        assert!(result.summary.is_some(), "summary should be present");
        let summary = result.summary.unwrap();
        assert!(summary.contains("Summary") || summary.contains("Rust"));
        assert!(result.llm_summary_error.is_none());
    }

    /// An LLM failure during summarization must not fail the whole query;
    /// the error text is surfaced in `llm_summary_error` instead.
    #[tokio::test]
    async fn pipeline_summarization_error_is_captured() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let llm_server = MockServer::start().await;
        // LLM endpoint always returns HTTP 500.
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&llm_server)
            .await;
        let page_url = format!("{}/page1", page_server.uri());
        let search_body = serde_json::json!({
            "results": [{"title": "Test", "url": &page_url, "content": "content"}]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page1"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><body><p>Content.</p></body></html>",
            ))
            .mount(&page_server)
            .await;
        let mut config =
            mock_config_with_llm(&search_server.uri(), &format!("{}/v1", llm_server.uri()));
        config.llm.expansion_enabled = false;
        config.llm.summarization_enabled = true;
        config.llm.llm_rerank_enabled = false;
        let result = query(&["test"], &config).await.unwrap();
        assert!(result.summary.is_none());
        assert!(result.llm_summary_error.is_some(), "should capture LLM error");
    }
}
#[cfg(test)]
#[cfg(feature = "backends")]
mod pipeline_tests {
    //! End-to-end tests of the non-LLM pipeline (search -> fetch -> clean ->
    //! rerank) using wiremock to fake the search backend and the target pages.
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a `Config` pointing the SearXNG backend at the given mock server,
    /// with small limits and timeouts for tests.
    fn mock_config(searxng_url: &str) -> Config {
        let mut config = Config::default();
        config.backends.searxng.url = searxng_url.to_string();
        config.server.max_result_length = 4000;
        config.server.max_query_budget = 16000;
        config.server.max_total_results = 5;
        config.server.search_timeout = 5;
        config
    }

    /// Happy path: two hits, both fetched and cleaned, ids renumbered 1..n.
    #[tokio::test]
    async fn pipeline_search_fetch_clean_rerank() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let page_url_1 = format!("{}/page1", page_server.uri());
        let page_url_2 = format!("{}/page2", page_server.uri());
        let search_body = serde_json::json!({
            "results": [
                {"title": "Rust Programming", "url": &page_url_1, "content": "Learn Rust systems programming"},
                {"title": "Tokio Async", "url": &page_url_2, "content": "Async runtime for Rust"},
            ]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page1"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><head><title>Rust Programming</title></head>\
                <body><p>Rust is a systems programming language focused on safety.</p></body></html>",
            ))
            .mount(&page_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page2"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><head><title>Tokio Tutorial</title></head>\
                <body><p>Tokio is an async runtime for Rust applications.</p></body></html>",
            ))
            .mount(&page_server)
            .await;
        let config = mock_config(&search_server.uri());
        let result = query(&["rust programming"], &config).await.unwrap();
        assert_eq!(result.queries, vec!["rust programming"]);
        assert_eq!(result.sources.len(), 2);
        assert_eq!(result.stats.fetched, 2);
        assert_eq!(result.stats.failed, 0);
        assert!(!result.sources[0].content.is_empty());
        assert!(!result.sources[1].content.is_empty());
        assert_eq!(result.sources[0].id, 1);
        assert_eq!(result.sources[1].id, 2);
    }

    /// One unreachable URL (TEST-NET address) must be counted as failed while
    /// still producing a snippet-backed source entry.
    #[tokio::test]
    async fn pipeline_handles_fetch_failure() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let page_url = format!("{}/good", page_server.uri());
        let search_body = serde_json::json!({
            "results": [
                {"title": "Good Page", "url": &page_url, "content": "Good snippet"},
                {"title": "Bad Page", "url": "http://192.0.2.1:1/nonexistent", "content": "Will fail"},
            ]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/good"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><body><p>Good content here.</p></body></html>",
            ))
            .mount(&page_server)
            .await;
        let config = mock_config(&search_server.uri());
        let result = query(&["test"], &config).await.unwrap();
        assert_eq!(result.stats.fetched, 1);
        assert_eq!(result.stats.failed, 1);
        assert_eq!(result.sources.len(), 2);
    }

    /// Two hits with the same URL collapse to one source (and one fetch).
    #[tokio::test]
    async fn pipeline_deduplicates_urls() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let page_url = format!("{}/page", page_server.uri());
        let search_body = serde_json::json!({
            "results": [
                {"title": "Page", "url": &page_url, "content": "Snippet 1"},
                {"title": "Page Dup", "url": &page_url, "content": "Snippet 2"},
            ]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("<html><body><p>Page content</p></body></html>"),
            )
            .mount(&page_server)
            .await;
        let config = mock_config(&search_server.uri());
        let result = query(&["test"], &config).await.unwrap();
        assert_eq!(result.sources.len(), 1);
        assert_eq!(result.stats.fetched, 1);
    }

    /// Hits pointing at binary file types (.pdf, .zip) are filtered out before
    /// fetching; only the HTML page survives.
    #[tokio::test]
    async fn pipeline_filters_binary_urls() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let good_url = format!("{}/page", page_server.uri());
        let search_body = serde_json::json!({
            "results": [
                {"title": "Good", "url": &good_url, "content": "Good page"},
                {"title": "PDF", "url": "https://example.com/file.pdf", "content": "A PDF"},
                {"title": "ZIP", "url": "https://example.com/file.zip", "content": "A ZIP"},
            ]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("<html><body><p>Content</p></body></html>"),
            )
            .mount(&page_server)
            .await;
        let config = mock_config(&search_server.uri());
        let result = query(&["test"], &config).await.unwrap();
        assert_eq!(result.sources.len(), 1);
        assert!(result.sources[0].url.contains("/page"));
    }

    /// Two queries against the same backend are merged round-robin; both
    /// queries must be echoed back in `result.queries`.
    #[tokio::test]
    async fn pipeline_multiple_queries_round_robin() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let url1 = format!("{}/rust", page_server.uri());
        let url2 = format!("{}/tokio", page_server.uri());
        let search_body = serde_json::json!({
            "results": [
                {"title": "Rust", "url": &url1, "content": "Rust lang"},
                {"title": "Tokio", "url": &url2, "content": "Async Rust"},
            ]
        });
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        for p in ["/rust", "/tokio"] {
            Mock::given(method("GET"))
                .and(path(p))
                .respond_with(
                    ResponseTemplate::new(200)
                        .set_body_string(&format!("<html><body><p>{p} content</p></body></html>")),
                )
                .mount(&page_server)
                .await;
        }
        let config = mock_config(&search_server.uri());
        let result = query(&["rust", "async"], &config).await.unwrap();
        assert!(result.sources.len() >= 1);
        assert_eq!(result.queries, vec!["rust", "async"]);
    }

    /// Selecting a backend name that does not exist yields a descriptive error.
    #[tokio::test]
    async fn pipeline_unknown_backend_returns_error() {
        let config = Config::default();
        let result = query_with_options(&["test"], &config, None, None, Some("nonexistent")).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unknown backend"));
    }

    /// An empty query list is rejected.
    #[tokio::test]
    async fn pipeline_empty_queries_returns_error() {
        let config = Config::default();
        let result = query(&[], &config).await;
        assert!(result.is_err());
    }

    /// With oversampling and a low `max_total_results`, surplus hits beyond
    /// the fetched candidates must land in the snippet pool.
    #[tokio::test]
    async fn pipeline_snippet_pool_contains_reserves() {
        let search_server = MockServer::start().await;
        let page_server = MockServer::start().await;
        let mut results = Vec::new();
        for i in 0..8 {
            let url = format!("{}/page{i}", page_server.uri());
            results.push(serde_json::json!({
                "title": format!("Page {i}"),
                "url": url,
                "content": format!("Snippet {i}"),
            }));
        }
        let search_body = serde_json::json!({"results": results});
        Mock::given(method("GET"))
            .and(path("/search"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&search_body))
            .mount(&search_server)
            .await;
        for i in 0..8 {
            Mock::given(method("GET"))
                .and(path(format!("/page{i}")))
                .respond_with(ResponseTemplate::new(200).set_body_string(&format!(
                    "<html><body><p>Content for page {i}</p></body></html>"
                )))
                .mount(&page_server)
                .await;
        }
        let mut config = mock_config(&search_server.uri());
        config.server.max_total_results = 3;
        config.server.results_per_query = 3;
        config.server.oversampling_factor = 3;
        let result = query(&["test"], &config).await.unwrap();
        assert_eq!(result.sources.len(), 3);
        assert!(result.snippet_pool.len() > 0, "snippet pool should have reserves");
    }
}
#[cfg(test)]
mod clean_tests {
    //! Unit tests for the synchronous [`clean`] entry point.
    use super::*;

    /// Cleaned text keeps the body content and `char_count` mirrors `text.len()`.
    #[test]
    fn clean_returns_correct_fields() {
        let result = clean("<html><body><p>hello world</p></body></html>", 8000);
        assert!(
            result.text.contains("hello world"),
            "text should contain 'hello world', got: {}",
            result.text
        );
        assert_eq!(result.char_count, result.text.len());
        assert!(!result.truncated);
    }

    /// A tiny `max_chars` must set the truncated flag and bound the output.
    #[test]
    fn clean_truncated_flag() {
        let result = clean(
            "<html><body><p>hello world this is long content</p></body></html>",
            5,
        );
        assert!(result.truncated, "should be truncated with max_chars=5");
        assert!(
            result.char_count <= 5,
            "char_count ({}) should be <= 5",
            result.char_count
        );
    }

    /// Empty input yields empty, non-truncated output.
    #[test]
    fn clean_empty_html() {
        let result = clean("", 8000);
        assert!(
            result.text.is_empty(),
            "empty HTML should produce empty text, got: {:?}",
            result.text
        );
        assert_eq!(result.char_count, 0);
        assert!(!result.truncated);
    }

    /// Boilerplate elements (nav/script/footer) are stripped while real
    /// paragraph content is preserved.
    #[test]
    fn clean_with_noise_elements() {
        let html = r#"<html><body>
<nav>Navigation menu</nav>
<script>alert('xss')</script>
<p>Real content here</p>
<footer>Footer stuff</footer>
</body></html>"#;
        let result = clean(html, 8000);
        assert!(
            result.text.contains("Real content here"),
            "should keep real content"
        );
        assert!(
            !result.text.contains("alert"),
            "should strip script content"
        );
        assert!(
            !result.text.contains("Navigation menu"),
            "should strip nav content"
        );
        assert!(
            !result.text.contains("Footer stuff"),
            "should strip footer content"
        );
    }
}
#[cfg(test)]
mod fetch_tests {
    //! Tests for the single-URL [`fetch`] entry point: pre-fetch filtering
    //! (binary extensions, domain allow/block lists) and the happy path.
    use super::*;

    /// A .pdf URL is rejected before any network activity.
    #[tokio::test]
    async fn fetch_binary_url_rejected() {
        let cfg = Config::default();
        let result = fetch("https://example.com/file.pdf", &cfg).await;
        assert!(result.is_err(), "binary URL should be rejected");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("binary file URL filtered"),
            "error should mention binary filter, got: {err}"
        );
    }

    /// A URL on the blocked-domains list is rejected.
    #[tokio::test]
    async fn fetch_blocked_domain_rejected() {
        let mut cfg = Config::default();
        cfg.server.blocked_domains = vec!["blocked.example.com".to_string()];
        let result = fetch("https://blocked.example.com/page", &cfg).await;
        assert!(result.is_err(), "blocked domain should be rejected");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("blocked by domain filter"),
            "error should mention domain filter, got: {err}"
        );
    }

    /// With an allow-list configured, any other domain is rejected.
    #[tokio::test]
    async fn fetch_allowed_domain_not_matching() {
        let mut cfg = Config::default();
        cfg.server.allowed_domains = vec!["allowed.example.com".to_string()];
        let result = fetch("https://other.example.com/page", &cfg).await;
        assert!(result.is_err(), "non-allowed domain should be rejected");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("blocked by domain filter"),
            "error should mention domain filter, got: {err}"
        );
    }

    /// Happy path against a mock server: URL echoed back, body cleaned,
    /// title extracted, char_count consistent with the text.
    #[tokio::test]
    async fn fetch_returns_correct_fields() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/test-page"))
            .respond_with(ResponseTemplate::new(200).set_body_string(
                "<html><head><title>Test Title</title></head>\
                <body><p>Test body content for fetching.</p></body></html>",
            ))
            .mount(&server)
            .await;
        let cfg = Config::default();
        let url = format!("{}/test-page", server.uri());
        let result = fetch(&url, &cfg).await.unwrap();
        assert_eq!(result.url, url);
        assert!(
            result.text.contains("Test body content"),
            "text should contain page body"
        );
        assert_eq!(result.char_count, result.text.len());
        assert!(
            result.title.contains("Test Title"),
            "title should be extracted"
        );
    }
}
#[cfg(test)]
#[cfg(feature = "backends")]
mod query_edge_tests {
    //! Edge-case tests for [`query_with_options`].
    use super::*;

    /// An empty query slice must be rejected before any backend work happens.
    #[tokio::test]
    async fn query_empty_queries_returns_error() {
        let config = Config::default();
        let outcome = query_with_options(&[], &config, None, None, None).await;
        assert!(outcome.is_err(), "empty queries should return an error");
    }
}