use std::time::{Duration, Instant};
use crate::error::{Error, Result};
use crate::filter::{CompiledFilter, FilterExpr};
use crate::index::{SearchResult, VectorIndex};
use crate::retrieval::rerank::{Reranker, RerankerConfig};
use crate::sparse::{BM25Index, HybridFusionStrategy, HybridResult, HybridSearchConfig, HybridSearcher};
use crate::stats::OutcomeStats;
use crate::store::RecordStore;
use crate::types::{MemoryRecord, PriorBundle, RecordId};
/// Configuration for [`HybridQueryEngine`]: result-count defaults and limits,
/// time budget, fusion settings, optional reranking and prior construction.
#[derive(Debug, Clone)]
pub struct HybridQueryEngineConfig {
/// Number of results returned when a request does not set `k`.
pub default_k: usize,
/// Upper bound on `k`: an explicit request `k` above it is rejected, and the
/// default `k` is clamped to it.
pub max_k: usize,
/// Query time budget in milliseconds, used when a request has no timeout of its own.
pub timeout_ms: u64,
/// Settings handed to the `HybridSearcher` (fusion strategy, candidates per index, ...).
pub hybrid_config: HybridSearchConfig,
/// When set, a `Reranker` is built from this config and applied to every query.
pub reranker: Option<RerankerConfig>,
/// Whether to compute a `PriorBundle` from the returned records.
pub build_priors: bool,
/// Behaviour for requests without a text query: `true` runs a dense-only search,
/// `false` rejects the request.
pub fallback_to_dense: bool,
}
impl Default for HybridQueryEngineConfig {
    /// Defaults: 10 results per query (never more than 1000), a 5 second budget,
    /// the default hybrid search settings, no reranker, priors enabled and
    /// dense-only fallback enabled.
    fn default() -> Self {
        let hybrid_config = HybridSearchConfig::default();
        Self {
            hybrid_config,
            reranker: None,
            timeout_ms: 5_000,
            default_k: 10,
            max_k: 1_000,
            build_priors: true,
            fallback_to_dense: true,
        }
    }
}
impl HybridQueryEngineConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of results returned when a request does not specify `k`.
    #[must_use]
    pub const fn with_default_k(mut self, k: usize) -> Self {
        self.default_k = k;
        self
    }

    /// Sets the maximum `k`. Requests asking for more are rejected, and the
    /// default `k` is clamped to this value.
    #[must_use]
    pub const fn with_max_k(mut self, max_k: usize) -> Self {
        self.max_k = max_k;
        self
    }

    /// Sets the default query time budget in milliseconds.
    #[must_use]
    pub const fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        self
    }

    /// Sets the strategy used to fuse dense and sparse scores.
    #[must_use]
    pub fn with_fusion_strategy(mut self, strategy: HybridFusionStrategy) -> Self {
        self.hybrid_config.strategy = strategy;
        self
    }

    /// Sets how many candidates the hybrid searcher pulls from each index.
    #[must_use]
    pub fn with_candidates_per_index(mut self, n: usize) -> Self {
        self.hybrid_config.candidates_per_index = n;
        self
    }

    /// Enables reranking of retrieved results with the given configuration.
    #[must_use]
    pub fn with_reranker(mut self, config: RerankerConfig) -> Self {
        self.reranker = Some(config);
        self
    }

    /// Enables or disables building a `PriorBundle` from the returned records.
    #[must_use]
    pub const fn with_build_priors(mut self, build: bool) -> Self {
        self.build_priors = build;
        self
    }

    /// Chooses what happens when a request carries no text query: `true` runs a
    /// dense-only search, `false` rejects the request.
    #[must_use]
    pub const fn with_fallback_to_dense(mut self, enabled: bool) -> Self {
        self.fallback_to_dense = enabled;
        self
    }
}
/// A single query against a [`HybridQueryEngine`]. Every `Option` field left as
/// `None` falls back to the engine's configuration.
#[derive(Debug, Clone)]
pub struct HybridQueryRequest {
/// Query vector for the dense index. It must be non-empty and, when the index
/// reports a non-zero dimension, have exactly that length.
pub embedding: Vec<f32>,
/// Keyword query for the BM25 index. Its presence makes the query hybrid.
pub text_query: Option<String>,
/// Number of results wanted. It must be > 0 and <= the engine's `max_k`.
pub k: Option<usize>,
/// Metadata filter applied to the retrieved records.
pub filter: Option<FilterExpr>,
/// Per-request override of the engine's time budget, in milliseconds.
pub timeout_ms: Option<u64>,
/// Per-request override of the engine's fusion strategy.
pub fusion_strategy: Option<HybridFusionStrategy>,
}
impl HybridQueryRequest {
    /// Starts a dense-only request for `embedding`. All optional settings start
    /// unset, so the engine configuration supplies them.
    #[must_use]
    pub fn new(embedding: Vec<f32>) -> Self {
        Self {
            fusion_strategy: None,
            timeout_ms: None,
            filter: None,
            k: None,
            text_query: None,
            embedding,
        }
    }

    /// Attaches a keyword query, turning this into a hybrid (dense + BM25) request.
    #[must_use]
    pub fn with_text(mut self, text: impl Into<String>) -> Self {
        let owned: String = text.into();
        self.text_query = Some(owned);
        self
    }

    /// Requests exactly `k` results instead of the engine default.
    #[must_use]
    pub const fn with_k(mut self, k: usize) -> Self {
        self.k = Some(k);
        self
    }

    /// Restricts results to records whose metadata satisfies `filter`.
    #[must_use]
    pub fn with_filter(mut self, filter: FilterExpr) -> Self {
        let expr = Some(filter);
        self.filter = expr;
        self
    }

    /// Overrides the engine's time budget for this request only.
    #[must_use]
    pub const fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = Some(ms);
        self
    }

    /// Overrides the engine's score-fusion strategy for this request only.
    #[must_use]
    pub fn with_fusion_strategy(mut self, strategy: HybridFusionStrategy) -> Self {
        let chosen = Some(strategy);
        self.fusion_strategy = chosen;
        self
    }

    /// `true` when a text query is attached, i.e. the BM25 leg will participate.
    #[must_use]
    pub fn is_hybrid(&self) -> bool {
        let Self { text_query, .. } = self;
        text_query.is_some()
    }
}
/// One record returned by a hybrid query, with its scores and final position.
#[derive(Debug, Clone)]
pub struct HybridRetrievedRecord {
/// The full record loaded from the `RecordStore`.
pub record: MemoryRecord,
/// Final score: the fused score, or the reranker's score when reranking is enabled.
pub score: f32,
/// Dense-index component of the score, when the searcher reported one.
pub dense_score: Option<f32>,
/// Sparse (BM25) component of the score, when the searcher reported one.
pub sparse_score: Option<f32>,
/// 1-based position in the response.
pub rank: usize,
}
/// Result of [`HybridQueryEngine::query`].
#[derive(Debug, Clone)]
pub struct HybridQueryResponse {
/// Retrieved records, best first, with 1-based `rank`s.
pub results: Vec<HybridRetrievedRecord>,
/// Outcome statistics over `results`. `None` when prior building is disabled or
/// there are no results.
pub priors: Option<PriorBundle>,
/// Wall-clock time from the start of the query until the response was assembled.
pub latency: Duration,
/// `true` when the BM25 text leg ran; `false` for dense-only fallback.
pub used_hybrid: bool,
/// Number of search hits that resolved to a stored record, counted before
/// metadata filtering and truncation to `k`.
pub candidates_considered: usize,
}
impl HybridQueryResponse {
    /// The best-ranked record, or `None` when nothing was returned.
    #[must_use]
    pub fn top(&self) -> Option<&HybridRetrievedRecord> {
        let Self { results, .. } = self;
        results.first()
    }

    /// `true` when the query produced no results.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { results, .. } = self;
        results.is_empty()
    }

    /// Number of records in the response.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { results, .. } = self;
        results.len()
    }
}
/// Query engine that combines a dense vector index with a BM25 index and
/// resolves the fused hits to full records through a `RecordStore`.
///
/// The indices and store are borrowed for `'a`. The engine owns its
/// configuration and an optional `Reranker`.
pub struct HybridQueryEngine<'a, I: VectorIndex, S: RecordStore> {
/// Engine-wide defaults and limits.
config: HybridQueryEngineConfig,
/// Dense (embedding) index searched on every query.
dense_index: &'a I,
/// Sparse BM25 index, used when the request carries a text query.
sparse_index: &'a BM25Index,
/// Source of full records for the ids returned by the search.
store: &'a S,
/// Reranker applied after filtering, if one was configured.
reranker: Option<Reranker>,
}
impl<'a, I: VectorIndex, S: RecordStore> HybridQueryEngine<'a, I, S> {
    /// Creates an engine over the given dense index, BM25 index and record store.
    ///
    /// When `config.reranker` is set, a [`Reranker`] is built once here rather
    /// than per query.
    #[must_use]
    pub fn new(
        config: HybridQueryEngineConfig,
        dense_index: &'a I,
        sparse_index: &'a BM25Index,
        store: &'a S,
    ) -> Self {
        let reranker = config.reranker.clone().map(Reranker::new);
        Self {
            config,
            dense_index,
            sparse_index,
            store,
            reranker,
        }
    }

    /// Runs a hybrid (dense + BM25) or dense-only query and returns at most `k` records.
    ///
    /// The pipeline has these stages:
    /// 1. Validate the request.
    /// 2. Search.
    /// 3. Load the hits from the store.
    /// 4. Apply the metadata filter.
    /// 5. Optionally rerank.
    /// 6. Truncate to `k`.
    /// 7. Assign 1-based ranks.
    /// 8. Optionally build outcome priors.
    ///
    /// Filtering and reranking happen after the search. When either is active,
    /// the search therefore requests `max(k, hybrid_config.candidates_per_index)`
    /// hits instead of `k`. Without this, a filter would leave fewer than `k`
    /// results, and the reranker could never promote a hit from outside the
    /// fused top-k.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidQuery`] if validation fails, or if there is no text query
    ///   and dense fallback is disabled.
    /// - [`Error::QueryTimeout`] if the search phase exceeds the time budget.
    /// - Any error propagated from the underlying searcher.
    pub fn query(&self, request: HybridQueryRequest) -> Result<HybridQueryResponse> {
        let start = Instant::now();
        let timeout = Duration::from_millis(
            request.timeout_ms.unwrap_or(self.config.timeout_ms),
        );
        self.validate_query(&request)?;
        let k = request.k.unwrap_or(self.config.default_k).min(self.config.max_k);

        // Post-search stages can only shrink or reorder the candidate list, so
        // give them more material than the final `k` when they are active.
        let fetch_k = if request.filter.is_some() || self.reranker.is_some() {
            k.max(self.config.hybrid_config.candidates_per_index)
        } else {
            k
        };
        let (hybrid_results, used_hybrid) = self.execute_search(&request, fetch_k)?;

        let elapsed = start.elapsed();
        if elapsed > timeout {
            return Err(Error::QueryTimeout {
                elapsed_ms: Self::saturating_millis(elapsed),
                budget_ms: Self::saturating_millis(timeout),
            });
        }

        let mut results = self.build_results(hybrid_results)?;
        let candidates_considered = results.len();

        // `request` is owned and not used after this point, so the filter can be
        // moved into the compiler instead of being cloned.
        if let Some(filter_expr) = request.filter {
            let filter = CompiledFilter::compile(filter_expr);
            results.retain(|r| filter.evaluate(&r.record.metadata));
        }
        if let Some(reranker) = &self.reranker {
            results = self.rerank_results(reranker, results);
        }
        results.truncate(k);
        for (position, result) in results.iter_mut().enumerate() {
            result.rank = position + 1;
        }

        let priors = if self.config.build_priors && !results.is_empty() {
            Some(self.build_priors(&results))
        } else {
            None
        };
        Ok(HybridQueryResponse {
            results,
            priors,
            latency: start.elapsed(),
            used_hybrid,
            candidates_considered,
        })
    }

    /// Converts `duration` to whole milliseconds.
    ///
    /// `Duration::as_millis` returns a `u128`, so this saturates at `u64::MAX`
    /// instead of silently truncating with `as`.
    fn saturating_millis(duration: Duration) -> u64 {
        duration.as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Rejects malformed requests before any index is touched.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidQuery`] when any of these hold:
    /// - the embedding is empty;
    /// - the embedding contains a NaN or infinite component;
    /// - the embedding length differs from the dense index dimension (a reported
    ///   dimension of 0 skips this check);
    /// - an explicit `k` is 0 or exceeds `max_k`.
    fn validate_query(&self, request: &HybridQueryRequest) -> Result<()> {
        if request.embedding.is_empty() {
            return Err(Error::InvalidQuery {
                reason: "Empty embedding".into(),
            });
        }
        // NaN/inf components poison every distance computation downstream.
        if request.embedding.iter().any(|x| !x.is_finite()) {
            return Err(Error::InvalidQuery {
                reason: "Embedding contains non-finite values".into(),
            });
        }
        let dim = self.dense_index.dimension();
        if dim > 0 && request.embedding.len() != dim {
            return Err(Error::InvalidQuery {
                reason: format!(
                    "Dimension mismatch: query has {}, index expects {}",
                    request.embedding.len(),
                    dim
                ),
            });
        }
        if let Some(k) = request.k {
            if k == 0 {
                return Err(Error::InvalidQuery {
                    reason: "k must be > 0".into(),
                });
            }
            if k > self.config.max_k {
                return Err(Error::InvalidQuery {
                    reason: format!("k exceeds maximum ({})", self.config.max_k),
                });
            }
        }
        Ok(())
    }

    /// Runs the search phase, requesting `k` hits from the hybrid searcher.
    ///
    /// The fusion strategy is the request's override, or the configured one if
    /// the request sets none. With a text query, both legs run and the flag is
    /// `true`. Without one, a dense-only search runs (flag `false`) when
    /// `fallback_to_dense` is set.
    ///
    /// # Errors
    ///
    /// Propagates searcher errors. Returns [`Error::InvalidQuery`] when there is
    /// no text query and dense fallback is disabled.
    fn execute_search(
        &self,
        request: &HybridQueryRequest,
        k: usize,
    ) -> Result<(Vec<HybridResult>, bool)> {
        let mut hybrid_config = self.config.hybrid_config.clone();
        if let Some(strategy) = request.fusion_strategy {
            hybrid_config.strategy = strategy;
        }
        let searcher = HybridSearcher::new(self.dense_index, self.sparse_index)
            .with_config(hybrid_config);
        if let Some(ref text_query) = request.text_query {
            let results = searcher.search(&request.embedding, text_query, k)?;
            Ok((results, true))
        } else if self.config.fallback_to_dense {
            let results = searcher.search_dense_only(&request.embedding, k)?;
            Ok((results, false))
        } else {
            Err(Error::InvalidQuery {
                reason: "Text query required for hybrid search".into(),
            })
        }
    }

    /// Loads the full record for each search hit, preserving hit order.
    ///
    /// Hits whose id is not found in the store are skipped. `rank` is left at 0;
    /// `query` assigns final ranks later.
    fn build_results(
        &self,
        hybrid_results: Vec<HybridResult>,
    ) -> Result<Vec<HybridRetrievedRecord>> {
        let results = hybrid_results
            .into_iter()
            .filter_map(|hr| {
                let id: RecordId = hr.id.into();
                let record = self.store.get(&id)?;
                Some(HybridRetrievedRecord {
                    record,
                    score: hr.score,
                    dense_score: hr.dense_score,
                    sparse_score: hr.sparse_score,
                    rank: 0,
                })
            })
            .collect();
        Ok(results)
    }

    /// Reorders `results` with `reranker` and re-attaches the per-index scores.
    ///
    /// Each candidate is handed to the reranker with its 1-based pre-rerank
    /// position as `rank`. The reranker only carries a single score, so the
    /// dense/sparse components are remembered in a map keyed by the record id's
    /// string form. If several candidates share an id, the first one's scores
    /// win. A reranked record with no matching entry gets `None` for both
    /// components.
    fn rerank_results(
        &self,
        reranker: &Reranker,
        results: Vec<HybridRetrievedRecord>,
    ) -> Vec<HybridRetrievedRecord> {
        use crate::retrieval::engine::RetrievedRecord;
        use std::collections::HashMap;

        let mut component_scores: HashMap<String, (Option<f32>, Option<f32>)> =
            HashMap::with_capacity(results.len());
        let converted: Vec<RetrievedRecord> = results
            .into_iter()
            .enumerate()
            .map(|(position, r)| {
                component_scores
                    .entry(r.record.id.to_string())
                    .or_insert((r.dense_score, r.sparse_score));
                RetrievedRecord {
                    record: r.record,
                    score: r.score,
                    rank: position + 1,
                    source_index: "hybrid".to_string(),
                }
            })
            .collect();

        reranker
            .rerank(converted)
            .into_iter()
            .map(|rr| {
                let (dense_score, sparse_score) = component_scores
                    .get(&rr.record.id.to_string())
                    .copied()
                    .unwrap_or((None, None));
                HybridRetrievedRecord {
                    record: rr.record,
                    score: rr.score,
                    dense_score,
                    sparse_score,
                    rank: rr.rank,
                }
            })
            .collect()
    }

    /// Builds outcome priors from the returned records.
    ///
    /// For each result, the record's `outcome` is added as one scalar
    /// observation, and the record's own `stats` are merged in when they are
    /// 1-dimensional.
    ///
    /// Fallbacks: the mean and std default to 0.0 when unavailable, and the
    /// 95% interval falls back to `(mean, mean)`. The ids of the first three
    /// results become the prototype ids.
    ///
    /// NOTE(review): a record's `outcome` is counted in addition to its merged
    /// `stats`. Confirm this is intended and not double counting.
    fn build_priors(&self, results: &[HybridRetrievedRecord]) -> PriorBundle {
        let mut stats = OutcomeStats::new(1);
        for result in results {
            stats.update_scalar(result.record.outcome);
            if result.record.stats.dim() == 1 {
                stats = stats.merge(&result.record.stats);
            }
        }
        let mean = stats.mean_scalar().unwrap_or(0.0);
        let std_dev = stats.std_scalar().unwrap_or(0.0);
        let ci = stats
            .confidence_interval(0.95)
            .map(|(l, u)| (l[0] as f64, u[0] as f64))
            .unwrap_or((mean, mean));
        PriorBundle {
            mean_outcome: mean,
            std_outcome: std_dev,
            confidence_interval: ci,
            sample_count: stats.count(),
            prototype_ids: results.iter().take(3).map(|r| r.record.id.clone()).collect(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig};
    use crate::store::InMemoryStore;
    use crate::types::RecordStatus;

    /// Fixture corpus: (id, 4-d embedding, text indexed by BM25).
    const CORPUS: [(&str, [f32; 4], &str); 5] = [
        ("rec-1", [1.0, 0.0, 0.0, 0.0], "machine learning algorithms"),
        ("rec-2", [0.0, 1.0, 0.0, 0.0], "deep neural networks"),
        ("rec-3", [0.0, 0.0, 1.0, 0.0], "natural language processing"),
        ("rec-4", [1.0, 1.0, 0.0, 0.0], "reinforcement learning agents"),
        ("rec-5", [0.0, 1.0, 1.0, 0.0], "computer vision models"),
    ];

    /// An active record with a fixed 0.8 outcome and fresh 1-D stats.
    fn create_test_record(id: &str, embedding: Vec<f32>, text: &str) -> MemoryRecord {
        MemoryRecord {
            id: id.into(),
            context: String::from(text),
            embedding,
            outcome: 0.8,
            stats: OutcomeStats::new(1),
            status: RecordStatus::Active,
            created_at: 1_234_567_890,
            metadata: Default::default(),
        }
    }

    /// Indexes every `CORPUS` entry in a flat dense index, a BM25 index and a store.
    fn setup_test_env() -> (FlatIndex, BM25Index, InMemoryStore) {
        let mut dense_index = FlatIndex::new(IndexConfig::new(4));
        let mut sparse_index = BM25Index::new();
        let mut store = InMemoryStore::new();
        for &(id, vector, text) in CORPUS.iter() {
            let embedding = vector.to_vec();
            let record = create_test_record(id, embedding.clone(), text);
            dense_index.add(record.id.to_string(), &embedding).unwrap();
            sparse_index.add(record.id.to_string(), text);
            store.insert(record).unwrap();
        }
        (dense_index, sparse_index, store)
    }

    /// Builds the fixture environment, wires an engine with `config` and runs `request`.
    fn run_query(
        config: HybridQueryEngineConfig,
        request: HybridQueryRequest,
    ) -> Result<HybridQueryResponse> {
        let (dense, sparse, store) = setup_test_env();
        let engine = HybridQueryEngine::new(config, &dense, &sparse, &store);
        engine.query(request)
    }

    #[test]
    fn test_hybrid_query_basic() {
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0])
            .with_text("machine learning")
            .with_k(3);
        let response = run_query(HybridQueryEngineConfig::new(), request).unwrap();
        assert!(response.used_hybrid);
        assert!(!response.is_empty());
        assert!(response.len() <= 3);
        let best = response.top().unwrap();
        assert!(best.dense_score.is_some() || best.sparse_score.is_some());
    }

    #[test]
    fn test_dense_only_fallback() {
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0]).with_k(3);
        let response = run_query(HybridQueryEngineConfig::new(), request).unwrap();
        assert!(!response.used_hybrid);
        assert!(!response.is_empty());
        assert!(response
            .results
            .iter()
            .all(|hit| hit.dense_score.is_some() && hit.sparse_score.is_none()));
    }

    #[test]
    fn test_hybrid_with_rrf() {
        let config = HybridQueryEngineConfig::new()
            .with_fusion_strategy(HybridFusionStrategy::RRF { k: 60.0 });
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0])
            .with_text("learning")
            .with_k(5);
        let response = run_query(config, request).unwrap();
        assert!(response.used_hybrid);
        assert!(!response.is_empty());
    }

    #[test]
    fn test_hybrid_with_linear_fusion() {
        let config = HybridQueryEngineConfig::new()
            .with_fusion_strategy(HybridFusionStrategy::Linear { alpha: 0.7 });
        let request = HybridQueryRequest::new(vec![0.0, 1.0, 0.0, 0.0])
            .with_text("neural networks")
            .with_k(3);
        let response = run_query(config, request).unwrap();
        assert!(response.used_hybrid);
        assert!(response
            .results
            .iter()
            .any(|hit| hit.record.id.to_string() == "rec-2"));
    }

    #[test]
    fn test_query_strategy_override() {
        let config = HybridQueryEngineConfig::new()
            .with_fusion_strategy(HybridFusionStrategy::RRF { k: 60.0 });
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0])
            .with_text("machine learning")
            .with_k(3)
            .with_fusion_strategy(HybridFusionStrategy::Linear { alpha: 0.5 });
        let response = run_query(config, request).unwrap();
        assert!(response.used_hybrid);
        assert!(!response.is_empty());
    }

    #[test]
    fn test_hybrid_with_priors() {
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0])
            .with_text("learning")
            .with_k(5);
        let response = run_query(HybridQueryEngineConfig::new(), request).unwrap();
        let priors = response.priors.expect("priors should be built for non-empty results");
        assert!(priors.sample_count > 0);
    }

    #[test]
    fn test_empty_query_validation() {
        let outcome = run_query(HybridQueryEngineConfig::new(), HybridQueryRequest::new(vec![]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_k_zero_validation() {
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0]).with_k(0);
        let outcome = run_query(HybridQueryEngineConfig::new(), request);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_response_latency() {
        let request = HybridQueryRequest::new(vec![1.0, 0.0, 0.0, 0.0])
            .with_text("machine")
            .with_k(3);
        let response = run_query(HybridQueryEngineConfig::new(), request).unwrap();
        assert!(response.latency.as_micros() > 0);
    }
}