use crate::error::{Error, Result};
use crate::filter::{CompiledFilter, FilterExpr};
use crate::index::{
IndexRegistry, MultiIndexResults, ParallelSearcher, SearchResult,
rrf_fuse,
};
use crate::retrieval::rerank::{Reranker, RerankerConfig};
use crate::stats::OutcomeStats;
use crate::store::RecordStore;
use crate::types::{MemoryRecord, PriorBundle, RecordId};
use std::time::{Duration, Instant};
/// Engine-wide settings for [`QueryEngine`]; individual requests may override
/// `k` and `timeout_ms` via [`QueryRequest`].
#[derive(Debug, Clone)]
pub struct QueryEngineConfig {
/// Number of results returned when a request does not set `k`.
pub default_k: usize,
/// Upper bound on `k`: an explicit request `k` above this is rejected, and the
/// effective `k` (including `default_k`) is clamped to it.
pub max_k: usize,
/// Time budget in milliseconds used when a request does not set its own.
/// It is checked once, after the index search completes.
pub timeout_ms: u64,
/// When `true`, searches fan out through `ParallelSearcher` instead of the
/// registry's sequential search methods.
pub parallel_search: bool,
/// Optional reranker configuration; when set, a `Reranker` is built by
/// `QueryEngine::new` and applied to every query's results.
pub reranker: Option<RerankerConfig>,
/// Whether to compute a `PriorBundle` for non-empty responses.
pub build_priors: bool,
}
impl Default for QueryEngineConfig {
    /// Baseline settings: 10 results per query (never more than 1000), a
    /// 5000 ms budget, parallel fan-out, no reranker, and priors enabled.
    fn default() -> Self {
        let (default_k, max_k) = (10, 1000);
        let timeout_ms = 5_000;
        Self {
            max_k,
            default_k,
            reranker: None,
            timeout_ms,
            build_priors: true,
            parallel_search: true,
        }
    }
}
impl QueryEngineConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of results returned when a request does not specify `k`.
    #[must_use]
    pub const fn with_default_k(mut self, k: usize) -> Self {
        self.default_k = k;
        self
    }

    /// Sets the maximum `k` a request may ask for; larger values are rejected
    /// by the engine, and the effective `k` is clamped to this bound.
    #[must_use]
    pub const fn with_max_k(mut self, max_k: usize) -> Self {
        self.max_k = max_k;
        self
    }

    /// Sets the default per-query time budget in milliseconds.
    #[must_use]
    pub const fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = ms;
        self
    }

    /// Enables or disables fanning index searches out in parallel.
    #[must_use]
    pub const fn with_parallel_search(mut self, enabled: bool) -> Self {
        self.parallel_search = enabled;
        self
    }

    /// Attaches a reranker configuration; the engine builds the reranker
    /// when it is constructed.
    #[must_use]
    pub fn with_reranker(mut self, config: RerankerConfig) -> Self {
        self.reranker = Some(config);
        self
    }

    /// Enables or disables building outcome priors for non-empty responses.
    #[must_use]
    pub const fn with_build_priors(mut self, enabled: bool) -> Self {
        self.build_priors = enabled;
        self
    }
}
/// A single similarity query submitted to [`QueryEngine::query`].
#[derive(Debug, Clone)]
pub struct QueryRequest {
/// Query vector; must be non-empty and match the dimension of at least one
/// registered index.
pub embedding: Vec<f32>,
/// Number of results wanted; `None` falls back to the engine's `default_k`.
pub k: Option<usize>,
/// Metadata filter evaluated against each retrieved record after the search.
pub filter: Option<FilterExpr>,
/// Names of the indexes to search; `None` searches every registered index.
pub indexes: Option<Vec<String>>,
/// Per-request time budget in milliseconds; `None` uses the engine default.
pub timeout_ms: Option<u64>,
}
impl QueryRequest {
    /// Creates a request for `embedding` with every optional setting unset,
    /// so the engine's configured defaults apply.
    #[must_use]
    pub fn new(embedding: Vec<f32>) -> Self {
        Self {
            embedding,
            k: None,
            filter: None,
            indexes: None,
            timeout_ms: None,
        }
    }

    /// Requests exactly `k` results (must be in `1..=max_k`).
    #[must_use]
    pub const fn with_k(mut self, k: usize) -> Self {
        self.k = Some(k);
        self
    }

    /// Restricts results to records whose metadata satisfies `filter`.
    #[must_use]
    pub fn with_filter(mut self, filter: FilterExpr) -> Self {
        self.filter = Some(filter);
        self
    }

    /// Restricts the search to the named indexes.
    #[must_use]
    pub fn with_indexes(mut self, indexes: Vec<String>) -> Self {
        self.indexes = Some(indexes);
        self
    }

    /// Overrides the engine's default time budget for this request only.
    #[must_use]
    pub const fn with_timeout_ms(mut self, ms: u64) -> Self {
        self.timeout_ms = Some(ms);
        self
    }
}
/// One hit in a [`QueryResponse`], resolved from the record store.
#[derive(Debug, Clone)]
pub struct RetrievedRecord {
/// The full record fetched from the store by the hit's id.
pub record: MemoryRecord,
/// Index similarity score, or the fused RRF score when several indexes were
/// searched (possibly altered by the reranker — NOTE(review): confirm).
pub score: f32,
/// 1-based position in the final, filtered/reranked/truncated result list.
pub rank: usize,
/// Name of the index that produced the hit; for fused results, the first
/// source reported by the fusion step (empty if none was reported).
pub source_index: String,
}
/// Result of [`QueryEngine::query`].
#[derive(Debug, Clone)]
pub struct QueryResponse {
/// Ranked hits, best first, at most `k` long.
pub results: Vec<RetrievedRecord>,
/// Outcome priors; `None` when prior building is disabled or no results remain.
pub priors: Option<PriorBundle>,
/// Wall-clock time spent inside `query`.
pub latency: Duration,
/// Number of indexes that returned a result set.
pub indexes_searched: usize,
/// Records resolved from the store before filtering, reranking, and truncation.
pub candidates_considered: usize,
}
impl QueryResponse {
    /// Best-ranked hit, or `None` when the query matched nothing.
    #[must_use]
    pub fn top(&self) -> Option<&RetrievedRecord> {
        match self.results.as_slice() {
            [head, ..] => Some(head),
            [] => None,
        }
    }

    /// `true` when the response carries no hits.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        matches!(self.results.as_slice(), [])
    }

    /// Number of hits in the response.
    #[must_use]
    pub fn len(&self) -> usize {
        let hits: &[RetrievedRecord] = &self.results;
        hits.len()
    }
}
/// Runs similarity queries against the indexes in `registry`, resolving hits
/// into full records via `store`. Both are borrowed for the lifetime `'a`.
pub struct QueryEngine<'a, S: RecordStore> {
/// Engine-wide settings.
config: QueryEngineConfig,
/// Indexes searched by each query.
registry: &'a IndexRegistry,
/// Record store used to turn index hits into `MemoryRecord`s.
store: &'a S,
/// Built from `config.reranker` in `QueryEngine::new`, if configured.
reranker: Option<Reranker>,
}
impl<'a, S: RecordStore> QueryEngine<'a, S> {
#[must_use]
pub fn new(
config: QueryEngineConfig,
registry: &'a IndexRegistry,
store: &'a S,
) -> Self {
let reranker = config.reranker.clone().map(Reranker::new);
Self {
config,
registry,
store,
reranker,
}
}
pub fn query(&self, request: QueryRequest) -> Result<QueryResponse> {
let start = Instant::now();
let timeout = Duration::from_millis(
request.timeout_ms.unwrap_or(self.config.timeout_ms),
);
self.validate_query(&request)?;
let k = request.k.unwrap_or(self.config.default_k).min(self.config.max_k);
let (search_results, indexes_searched) = self.execute_search(&request, k)?;
if start.elapsed() > timeout {
return Err(Error::QueryTimeout {
elapsed_ms: start.elapsed().as_millis() as u64,
budget_ms: timeout.as_millis() as u64,
});
}
let mut results = self.build_results(search_results, &request)?;
let candidates_considered = results.len();
if let Some(ref filter_expr) = request.filter {
let filter = CompiledFilter::compile(filter_expr.clone());
results.retain(|r| filter.evaluate(&r.record.metadata));
}
if let Some(ref reranker) = self.reranker {
results = reranker.rerank(results);
}
results.truncate(k);
for (i, result) in results.iter_mut().enumerate() {
result.rank = i + 1;
}
let priors = if self.config.build_priors && !results.is_empty() {
Some(self.build_priors(&results))
} else {
None
};
Ok(QueryResponse {
results,
priors,
latency: start.elapsed(),
indexes_searched,
candidates_considered,
})
}
fn validate_query(&self, request: &QueryRequest) -> Result<()> {
if request.embedding.is_empty() {
return Err(Error::InvalidQuery {
reason: "Empty embedding".into(),
});
}
if let Some(k) = request.k {
if k == 0 {
return Err(Error::InvalidQuery {
reason: "k must be > 0".into(),
});
}
if k > self.config.max_k {
return Err(Error::InvalidQuery {
reason: format!("k exceeds maximum ({})", self.config.max_k),
});
}
}
let dim = request.embedding.len();
let has_compatible = self.registry.info().iter().any(|i| i.dimension == dim);
if !has_compatible {
return Err(Error::InvalidQuery {
reason: format!("No index with dimension {dim}"),
});
}
Ok(())
}
fn execute_search(
&self,
request: &QueryRequest,
k: usize,
) -> Result<(Vec<(String, SearchResult)>, usize)> {
let query = &request.embedding;
let multi_results: MultiIndexResults = if let Some(ref index_names) = request.indexes {
let names: Vec<&str> = index_names.iter().map(String::as_str).collect();
if self.config.parallel_search && names.len() > 1 {
let searcher = ParallelSearcher::new(self.registry);
searcher.search_indexes_parallel(&names, query, k)?
} else {
self.registry.search_indexes(&names, query, k)?
}
} else {
if self.config.parallel_search {
let searcher = ParallelSearcher::new(self.registry);
searcher.search_parallel(query, k)?
} else {
self.registry.search_all(query, k)?
}
};
let indexes_searched = multi_results.by_index.len();
let results: Vec<(String, SearchResult)> = if indexes_searched > 1 {
let fused = rrf_fuse(&multi_results);
fused
.into_iter()
.map(|f| {
let source = f.sources.first().cloned().unwrap_or_default();
(
source,
SearchResult {
id: f.id,
distance: 0.0, score: f.fused_score,
},
)
})
.collect()
} else {
multi_results.flatten()
};
Ok((results, indexes_searched))
}
fn build_results(
&self,
search_results: Vec<(String, SearchResult)>,
_request: &QueryRequest,
) -> Result<Vec<RetrievedRecord>> {
let mut results = Vec::with_capacity(search_results.len());
for (index_name, sr) in search_results {
let id: RecordId = sr.id.into();
if let Some(record) = self.store.get(&id) {
results.push(RetrievedRecord {
record,
score: sr.score,
rank: 0, source_index: index_name,
});
}
}
Ok(results)
}
fn build_priors(&self, results: &[RetrievedRecord]) -> PriorBundle {
let mut stats = OutcomeStats::new(1);
for result in results {
stats.update_scalar(result.record.outcome);
if result.record.stats.dim() == 1 {
stats = stats.merge(&result.record.stats);
}
}
let mean = stats.mean_scalar().unwrap_or(0.0);
let std_dev = stats.std_scalar().unwrap_or(0.0);
let ci = stats.confidence_interval(0.95)
.map(|(l, u)| (l[0] as f64, u[0] as f64))
.unwrap_or((mean, mean));
PriorBundle {
mean_outcome: mean,
std_outcome: std_dev,
confidence_interval: ci,
sample_count: stats.count(),
prototype_ids: results.iter().take(3).map(|r| r.record.id.clone()).collect(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig, VectorIndex};
    use crate::store::InMemoryStore;
    use crate::types::RecordStatus;
    use crate::OutcomeStats;

    /// Builds an active record with a fixed outcome of 0.8 and empty 1-D stats.
    fn create_test_record(id: &str, embedding: Vec<f32>) -> MemoryRecord {
        MemoryRecord {
            id: id.into(),
            embedding,
            context: format!("Context for {id}"),
            outcome: 0.8,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Registers one 4-dimensional flat index named "test" holding records
    /// `rec-0`..`rec-9` with embeddings `[i, 0, 0, 0]`, mirrored in the store.
    fn setup_test_env() -> (IndexRegistry, InMemoryStore) {
        let mut registry = IndexRegistry::new();
        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(4));
        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0, 0.0];
            let record = create_test_record(&format!("rec-{i}"), embedding.clone());
            index.add(record.id.to_string(), &embedding).unwrap();
            store.insert(record).unwrap();
        }
        registry.register("test", index).unwrap();
        (registry, store)
    }

    #[test]
    fn test_basic_query() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );
        let request = QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(3);
        let response = engine.query(request).unwrap();
        assert_eq!(response.len(), 3);
        assert!(!response.is_empty());
        assert!(response.priors.is_some());
    }

    #[test]
    fn test_query_validation_empty_embedding() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );
        let request = QueryRequest::new(vec![]);
        let result = engine.query(request);
        assert!(result.is_err());
    }

    #[test]
    fn test_query_validation_k_zero() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );
        let request = QueryRequest::new(vec![1.0, 0.0, 0.0, 0.0]).with_k(0);
        let result = engine.query(request);
        assert!(result.is_err());
    }

    #[test]
    fn test_query_with_priors() {
        let (registry, store) = setup_test_env();
        let config = QueryEngineConfig::new();
        let engine = QueryEngine::new(config, &registry, &store);
        let request = QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(5);
        let response = engine.query(request).unwrap();
        let priors = response.priors.unwrap();
        assert!(priors.sample_count > 0);
        assert!(!priors.prototype_ids.is_empty());
    }

    #[test]
    fn test_multi_index_query() {
        let mut registry = IndexRegistry::new();
        let mut store = InMemoryStore::new();
        let mut index1 = FlatIndex::new(IndexConfig::new(4));
        let mut index2 = FlatIndex::new(IndexConfig::new(4));
        let rec1 = create_test_record("rec-a", vec![1.0, 0.0, 0.0, 0.0]);
        index1.add(rec1.id.to_string(), &rec1.embedding).unwrap();
        store.insert(rec1).unwrap();
        let rec2 = create_test_record("rec-b", vec![0.0, 1.0, 0.0, 0.0]);
        index2.add(rec2.id.to_string(), &rec2.embedding).unwrap();
        store.insert(rec2).unwrap();
        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );
        let request = QueryRequest::new(vec![0.5, 0.5, 0.0, 0.0]).with_k(5);
        let response = engine.query(request).unwrap();
        assert_eq!(response.indexes_searched, 2);
        assert_eq!(response.len(), 2);
    }

    #[test]
    fn test_response_latency() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );
        let request = QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(3);
        let response = engine.query(request).unwrap();
        // Nanosecond resolution: a tiny in-memory query can finish in < 1 µs,
        // which made an `as_micros() > 0` assertion flaky.
        assert!(response.latency.as_nanos() > 0);
    }
}