use std::collections::HashMap;

use crate::error::Result;
use crate::quantum::QuantumSearch;
use crate::search::reranker::AttentionReranker;
use crate::search::router::QueryRouter;
use crate::storage::vector_store::{SearchResult, VectorStore};
pub struct EnhancedSearch {
router: QueryRouter,
reranker: Option<AttentionReranker>,
quantum: Option<QuantumSearch>,
}
impl EnhancedSearch {
pub fn new(dim: usize) -> Self {
Self {
router: QueryRouter::new(),
reranker: Some(AttentionReranker::new(dim, 4)),
quantum: Some(QuantumSearch::new()),
}
}
pub fn router_only() -> Self {
Self {
router: QueryRouter::new(),
reranker: None,
quantum: None,
}
}
pub fn router(&self) -> &QueryRouter {
&self.router
}
pub fn search(
&self,
query: &str,
query_embedding: &[f32],
store: &VectorStore,
k: usize,
) -> Result<Vec<SearchResult>> {
let _route = self.router.route(query);
let candidate_k = (k * 3).max(10).min(store.len().max(1));
let candidates = store.search(query_embedding, candidate_k)?;
if candidates.is_empty() {
return Ok(Vec::new());
}
let results = if let Some(ref reranker) = self.reranker {
let reranker_input: Vec<(String, f32, Vec<f32>)> = candidates
.iter()
.map(|sr| {
let embedding = store
.get(&sr.id)
.map(|stored| stored.vector.clone())
.unwrap_or_else(|| vec![0.0; query_embedding.len()]);
(sr.id.to_string(), sr.score, embedding)
})
.collect();
let rerank_k = if self.quantum.is_some() {
(k * 2).min(reranker_input.len())
} else {
k
};
let reranked = reranker.rerank(query_embedding, &reranker_input, rerank_k);
let final_scored = if let Some(ref quantum) = self.quantum {
quantum.diversity_select(&reranked, k)
} else {
let mut r = reranked;
r.truncate(k);
r
};
final_scored
.into_iter()
.filter_map(|(id_str, score)| {
let uid: uuid::Uuid = id_str.parse().ok()?;
let original = candidates.iter().find(|c| c.id == uid)?;
Some(SearchResult {
id: uid,
score,
metadata: original.metadata.clone(),
})
})
.collect()
} else {
candidates.into_iter().take(k).collect()
};
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capture::CapturedFrame;
    use crate::config::StorageConfig;
    use crate::storage::embedding::EmbeddingEngine;

    /// Embedding width shared by the engine and the search pipeline in every test.
    const DIM: usize = 384;

    /// Builds a default-configured store and inserts each frame, embedded with `engine`.
    fn store_with(engine: &EmbeddingEngine, frames: &[CapturedFrame]) -> VectorStore {
        let mut store = VectorStore::new(StorageConfig::default()).unwrap();
        for frame in frames {
            let emb = engine.embed(frame.text_content());
            store.insert(frame, &emb).unwrap();
        }
        store
    }

    #[test]
    fn test_enhanced_search_empty_store() {
        let engine = EmbeddingEngine::new(DIM);
        let store = store_with(&engine, &[]);
        let query_emb = engine.embed("test query");
        let results = EnhancedSearch::new(DIM)
            .search("test query", &query_emb, &store, 5)
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_enhanced_search_returns_results() {
        let engine = EmbeddingEngine::new(DIM);
        let frames = [
            CapturedFrame::new_screen("Editor", "code.rs", "implementing vector search in Rust", 0),
            CapturedFrame::new_screen("Browser", "docs", "Rust vector database documentation", 0),
            CapturedFrame::new_audio("Mic", "discussing Python machine learning", None),
        ];
        let store = store_with(&engine, &frames);
        let query_emb = engine.embed("vector search Rust");
        let results = EnhancedSearch::new(DIM)
            .search("vector search Rust", &query_emb, &store, 2)
            .unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_enhanced_search_router_only() {
        let engine = EmbeddingEngine::new(DIM);
        let frames = [CapturedFrame::new_screen("App", "Win", "test content", 0)];
        let store = store_with(&engine, &frames);
        let query_emb = engine.embed("test content");
        let results = EnhancedSearch::router_only()
            .search("test content", &query_emb, &store, 5)
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_enhanced_search_respects_k() {
        let engine = EmbeddingEngine::new(DIM);
        let frames: Vec<CapturedFrame> = (0..10)
            .map(|i| CapturedFrame::new_screen("App", "Win", &format!("content {}", i), 0))
            .collect();
        let store = store_with(&engine, &frames);
        let query_emb = engine.embed("content");
        let results = EnhancedSearch::new(DIM)
            .search("content", &query_emb, &store, 3)
            .unwrap();
        assert!(results.len() <= 3, "Should return at most k=3 results, got {}", results.len());
    }
}