1use crate::error::{Error, Result};
6use crate::filter::{CompiledFilter, FilterExpr};
7use crate::index::{
8 IndexRegistry, MultiIndexResults, ParallelSearcher, SearchResult,
9 rrf_fuse,
10};
11use crate::retrieval::rerank::{Reranker, RerankerConfig};
12use crate::stats::OutcomeStats;
13use crate::store::RecordStore;
14use crate::types::{MemoryRecord, PriorBundle, RecordId};
15use std::time::{Duration, Instant};
16
/// Tunable settings for a [`QueryEngine`].
///
/// The values are read by the engine on every query; see the individual fields for
/// where each one takes effect.
#[derive(Debug, Clone)]
pub struct QueryEngineConfig {
    /// Number of results to return when a request does not specify `k`.
    pub default_k: usize,
    /// Upper bound on `k`: a request asking for more is rejected during validation,
    /// and the effective `k` is clamped to this value.
    pub max_k: usize,
    /// Time budget in milliseconds used when a request does not carry its own
    /// `timeout_ms`.
    pub timeout_ms: u64,
    /// When `true`, searches are dispatched through `ParallelSearcher` instead of
    /// calling the registry's search methods directly.
    pub parallel_search: bool,
    /// Optional reranker settings; when present, the engine builds a `Reranker`
    /// from them at construction time.
    pub reranker: Option<RerankerConfig>,
    /// When `true`, a `PriorBundle` is computed for every non-empty result set.
    pub build_priors: bool,
}
33
34impl Default for QueryEngineConfig {
35 fn default() -> Self {
36 Self {
37 default_k: 10,
38 max_k: 1000,
39 timeout_ms: 5000,
40 parallel_search: true,
41 reranker: None,
42 build_priors: true,
43 }
44 }
45}
46
47impl QueryEngineConfig {
48 #[must_use]
50 pub fn new() -> Self {
51 Self::default()
52 }
53
54 #[must_use]
56 pub const fn with_default_k(mut self, k: usize) -> Self {
57 self.default_k = k;
58 self
59 }
60
61 #[must_use]
63 pub const fn with_timeout_ms(mut self, ms: u64) -> Self {
64 self.timeout_ms = ms;
65 self
66 }
67
68 #[must_use]
70 pub fn with_reranker(mut self, config: RerankerConfig) -> Self {
71 self.reranker = Some(config);
72 self
73 }
74}
75
/// A single similarity query submitted to [`QueryEngine::query`].
#[derive(Debug, Clone)]
pub struct QueryRequest {
    /// Query vector. Validation rejects an empty vector and a vector whose length
    /// matches the dimension of no registered index.
    pub embedding: Vec<f32>,
    /// Number of results wanted. `None` falls back to the engine's `default_k`;
    /// `Some(0)` and values above the engine's `max_k` are rejected.
    pub k: Option<usize>,
    /// Optional filter expression, evaluated against each retrieved record's
    /// metadata after the index search.
    pub filter: Option<FilterExpr>,
    /// Names of the indexes to search. `None` searches every registered index.
    pub indexes: Option<Vec<String>>,
    /// Per-request time budget in milliseconds, overriding the engine default.
    pub timeout_ms: Option<u64>,
}
90
91impl QueryRequest {
92 #[must_use]
94 pub fn new(embedding: Vec<f32>) -> Self {
95 Self {
96 embedding,
97 k: None,
98 filter: None,
99 indexes: None,
100 timeout_ms: None,
101 }
102 }
103
104 #[must_use]
106 pub const fn with_k(mut self, k: usize) -> Self {
107 self.k = Some(k);
108 self
109 }
110
111 #[must_use]
113 pub fn with_filter(mut self, filter: FilterExpr) -> Self {
114 self.filter = Some(filter);
115 self
116 }
117
118 #[must_use]
120 pub fn with_indexes(mut self, indexes: Vec<String>) -> Self {
121 self.indexes = Some(indexes);
122 self
123 }
124}
125
/// One entry of a [`QueryResponse`]: a stored record together with how it was found.
#[derive(Debug, Clone)]
pub struct RetrievedRecord {
    /// The record fetched from the record store for this hit.
    pub record: MemoryRecord,
    /// Score carried over from the search result; when several indexes were
    /// searched this is the fused score produced by `rrf_fuse`.
    pub score: f32,
    /// 1-based position in the final result list, assigned after filtering,
    /// reranking and truncation. It is `0` until that assignment happens.
    pub rank: usize,
    /// Name of the index the hit came from. For fused results this is the first
    /// listed source, or an empty string when no source is recorded.
    pub source_index: String,
}
138
/// Outcome of a successful [`QueryEngine::query`] call.
#[derive(Debug, Clone)]
pub struct QueryResponse {
    /// Retrieved records in rank order, at most `k` entries.
    pub results: Vec<RetrievedRecord>,
    /// Aggregate statistics over `results`; `Some` only when prior building is
    /// enabled in the engine configuration and `results` is non-empty.
    pub priors: Option<PriorBundle>,
    /// Time elapsed between the start of the query and the construction of this
    /// response.
    pub latency: Duration,
    /// Number of per-index entries in the search output.
    pub indexes_searched: usize,
    /// Number of hits that were resolved to stored records, counted before
    /// filtering, reranking and truncation.
    pub candidates_considered: usize,
}
153
154impl QueryResponse {
155 #[must_use]
157 pub fn top(&self) -> Option<&RetrievedRecord> {
158 self.results.first()
159 }
160
161 #[must_use]
163 pub fn is_empty(&self) -> bool {
164 self.results.is_empty()
165 }
166
167 #[must_use]
169 pub fn len(&self) -> usize {
170 self.results.len()
171 }
172}
173
/// Runs similarity queries against an [`IndexRegistry`] and resolves the hits to
/// full records from a [`RecordStore`].
///
/// The engine borrows both the registry and the store for the lifetime `'a`; it
/// owns only its configuration and the optional reranker.
pub struct QueryEngine<'a, S: RecordStore> {
    /// Settings consulted on every query.
    config: QueryEngineConfig,
    /// Indexes that are searched.
    registry: &'a IndexRegistry,
    /// Store used to turn hit ids into records.
    store: &'a S,
    /// Reranker built from `config.reranker` at construction time, if configured.
    reranker: Option<Reranker>,
}
192
193impl<'a, S: RecordStore> QueryEngine<'a, S> {
194 #[must_use]
196 pub fn new(
197 config: QueryEngineConfig,
198 registry: &'a IndexRegistry,
199 store: &'a S,
200 ) -> Self {
201 let reranker = config.reranker.clone().map(Reranker::new);
202 Self {
203 config,
204 registry,
205 store,
206 reranker,
207 }
208 }
209
210 pub fn query(&self, request: QueryRequest) -> Result<QueryResponse> {
216 let start = Instant::now();
217 let timeout = Duration::from_millis(
218 request.timeout_ms.unwrap_or(self.config.timeout_ms),
219 );
220
221 self.validate_query(&request)?;
223
224 let k = request.k.unwrap_or(self.config.default_k).min(self.config.max_k);
226
227 let (search_results, indexes_searched) = self.execute_search(&request, k)?;
229
230 if start.elapsed() > timeout {
232 return Err(Error::QueryTimeout {
233 elapsed_ms: start.elapsed().as_millis() as u64,
234 budget_ms: timeout.as_millis() as u64,
235 });
236 }
237
238 let mut results = self.build_results(search_results, &request)?;
240 let candidates_considered = results.len();
241
242 if let Some(ref filter_expr) = request.filter {
244 let filter = CompiledFilter::compile(filter_expr.clone());
245 results.retain(|r| filter.evaluate(&r.record.metadata));
246 }
247
248 if let Some(ref reranker) = self.reranker {
250 results = reranker.rerank(results);
251 }
252
253 results.truncate(k);
255
256 for (i, result) in results.iter_mut().enumerate() {
258 result.rank = i + 1;
259 }
260
261 let priors = if self.config.build_priors && !results.is_empty() {
263 Some(self.build_priors(&results))
264 } else {
265 None
266 };
267
268 Ok(QueryResponse {
269 results,
270 priors,
271 latency: start.elapsed(),
272 indexes_searched,
273 candidates_considered,
274 })
275 }
276
277 fn validate_query(&self, request: &QueryRequest) -> Result<()> {
279 if request.embedding.is_empty() {
280 return Err(Error::InvalidQuery {
281 reason: "Empty embedding".into(),
282 });
283 }
284
285 if let Some(k) = request.k {
286 if k == 0 {
287 return Err(Error::InvalidQuery {
288 reason: "k must be > 0".into(),
289 });
290 }
291 if k > self.config.max_k {
292 return Err(Error::InvalidQuery {
293 reason: format!("k exceeds maximum ({})", self.config.max_k),
294 });
295 }
296 }
297
298 let dim = request.embedding.len();
300 let has_compatible = self.registry.info().iter().any(|i| i.dimension == dim);
301
302 if !has_compatible {
303 return Err(Error::InvalidQuery {
304 reason: format!("No index with dimension {dim}"),
305 });
306 }
307
308 Ok(())
309 }
310
311 fn execute_search(
313 &self,
314 request: &QueryRequest,
315 k: usize,
316 ) -> Result<(Vec<(String, SearchResult)>, usize)> {
317 let query = &request.embedding;
318
319 let multi_results: MultiIndexResults = if let Some(ref index_names) = request.indexes {
321 let names: Vec<&str> = index_names.iter().map(String::as_str).collect();
323 if self.config.parallel_search && names.len() > 1 {
324 let searcher = ParallelSearcher::new(self.registry);
325 searcher.search_indexes_parallel(&names, query, k)?
326 } else {
327 self.registry.search_indexes(&names, query, k)?
328 }
329 } else {
330 if self.config.parallel_search {
332 let searcher = ParallelSearcher::new(self.registry);
333 searcher.search_parallel(query, k)?
334 } else {
335 self.registry.search_all(query, k)?
336 }
337 };
338
339 let indexes_searched = multi_results.by_index.len();
340
341 let results: Vec<(String, SearchResult)> = if indexes_searched > 1 {
343 let fused = rrf_fuse(&multi_results);
344 fused
345 .into_iter()
346 .map(|f| {
347 let source = f.sources.first().cloned().unwrap_or_default();
348 (
349 source,
350 SearchResult {
351 id: f.id,
352 distance: 0.0, score: f.fused_score,
354 },
355 )
356 })
357 .collect()
358 } else {
359 multi_results.flatten()
360 };
361
362 Ok((results, indexes_searched))
363 }
364
365 fn build_results(
367 &self,
368 search_results: Vec<(String, SearchResult)>,
369 _request: &QueryRequest,
370 ) -> Result<Vec<RetrievedRecord>> {
371 let mut results = Vec::with_capacity(search_results.len());
372
373 for (index_name, sr) in search_results {
374 let id: RecordId = sr.id.into();
375
376 if let Some(record) = self.store.get(&id) {
377 results.push(RetrievedRecord {
378 record,
379 score: sr.score,
380 rank: 0, source_index: index_name,
382 });
383 }
384 }
385
386 Ok(results)
387 }
388
389 fn build_priors(&self, results: &[RetrievedRecord]) -> PriorBundle {
391 let mut stats = OutcomeStats::new(1);
392
393 for result in results {
394 stats.update_scalar(result.record.outcome);
395 if result.record.stats.dim() == 1 {
397 stats = stats.merge(&result.record.stats);
398 }
399 }
400
401 let mean = stats.mean_scalar().unwrap_or(0.0);
402 let std_dev = stats.std_scalar().unwrap_or(0.0);
403 let ci = stats.confidence_interval(0.95)
404 .map(|(l, u)| (l[0] as f64, u[0] as f64))
405 .unwrap_or((mean, mean));
406
407 PriorBundle {
408 mean_outcome: mean,
409 std_outcome: std_dev,
410 confidence_interval: ci,
411 sample_count: stats.count(),
412 prototype_ids: results.iter().take(3).map(|r| r.record.id.clone()).collect(),
413 }
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig, VectorIndex};
    use crate::store::InMemoryStore;
    use crate::types::RecordStatus;
    use crate::OutcomeStats;

    /// Builds an active record with a fixed outcome of `0.8` and empty metadata.
    fn create_test_record(id: &str, embedding: Vec<f32>) -> MemoryRecord {
        MemoryRecord {
            id: id.into(),
            embedding,
            context: format!("Context for {id}"),
            outcome: 0.8,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Registers one 4-dimensional flat index named `"test"` holding ten records
    /// `rec-0` … `rec-9`, each with embedding `[i, 0, 0, 0]`, and stores the same
    /// records in an in-memory store.
    fn setup_test_env() -> (IndexRegistry, InMemoryStore) {
        let mut registry = IndexRegistry::new();
        let mut store = InMemoryStore::new();

        let mut index = FlatIndex::new(IndexConfig::new(4));

        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0, 0.0];
            let record = create_test_record(&format!("rec-{i}"), embedding.clone());

            index.add(record.id.to_string(), &embedding).unwrap();
            store.insert(record).unwrap();
        }

        registry.register("test", index).unwrap();
        (registry, store)
    }

    #[test]
    fn test_basic_query() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );

        let request = QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(3);
        let response = engine.query(request).unwrap();

        assert_eq!(response.len(), 3);
        assert!(!response.is_empty());
        assert!(response.priors.is_some());
    }

    #[test]
    fn test_query_validation_empty_embedding() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );

        let request = QueryRequest::new(vec![]);
        let result = engine.query(request);

        assert!(result.is_err());
    }

    #[test]
    fn test_query_validation_k_zero() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );

        let request = QueryRequest::new(vec![1.0, 0.0, 0.0, 0.0]).with_k(0);
        let result = engine.query(request);

        assert!(result.is_err());
    }

    #[test]
    fn test_query_with_priors() {
        let (registry, store) = setup_test_env();
        let config = QueryEngineConfig::new();
        let engine = QueryEngine::new(config, &registry, &store);

        let request = QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(5);
        let response = engine.query(request).unwrap();

        let priors = response.priors.unwrap();
        assert!(priors.sample_count > 0);
        assert!(!priors.prototype_ids.is_empty());
    }

    #[test]
    fn test_multi_index_query() {
        let mut registry = IndexRegistry::new();
        let mut store = InMemoryStore::new();

        let mut index1 = FlatIndex::new(IndexConfig::new(4));
        let mut index2 = FlatIndex::new(IndexConfig::new(4));

        // One record per index so that both indexes contribute to the fused result.
        let rec1 = create_test_record("rec-a", vec![1.0, 0.0, 0.0, 0.0]);
        index1.add(rec1.id.to_string(), &rec1.embedding).unwrap();
        store.insert(rec1).unwrap();

        let rec2 = create_test_record("rec-b", vec![0.0, 1.0, 0.0, 0.0]);
        index2.add(rec2.id.to_string(), &rec2.embedding).unwrap();
        store.insert(rec2).unwrap();

        registry.register("idx1", index1).unwrap();
        registry.register("idx2", index2).unwrap();

        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );

        let request = QueryRequest::new(vec![0.5, 0.5, 0.0, 0.0]).with_k(5);
        let response = engine.query(request).unwrap();

        assert_eq!(response.indexes_searched, 2);
        assert_eq!(response.len(), 2);
    }

    #[test]
    fn test_response_latency() {
        let (registry, store) = setup_test_env();
        let engine = QueryEngine::new(
            QueryEngineConfig::new(),
            &registry,
            &store,
        );

        let request = QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(3);
        let response = engine.query(request).unwrap();

        // Checked at nanosecond granularity so that a sub-microsecond query does
        // not fail the assertion.
        assert!(response.latency.as_nanos() > 0);
    }
}