use crate::config::VyctorConfig;
use crate::embeddings::{create_provider, EmbeddingProvider};
use crate::reranker::{create_reranker, DocumentToRerank, Reranker};
use crate::storage::{SearchResult, Storage};
use anyhow::Result;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Search engine that embeds a query, searches the stored index with that
/// embedding, and optionally reranks the hits.
pub struct SearchEngine {
/// Index storage that is searched with the query embedding.
storage: Storage,
/// Provider used to turn the query text into an embedding.
embedder: Arc<dyn EmbeddingProvider>,
/// Optional reranker applied to the storage hits; `None` disables reranking.
reranker: Option<Arc<dyn Reranker>>,
/// Read from `config.reranker.top_k`; when a reranker is set, at least this
/// many candidates are requested from storage before reranking.
rerank_top_k: usize,
}
/// Wall-clock timings for the stages of a single `SearchEngine::search` call.
#[derive(Debug, Clone)]
pub struct SearchTiming {
/// Time spent computing the query embedding.
pub embed_time: Duration,
/// Time spent in the storage search.
pub search_time: Duration,
/// Time spent in the rerank stage; `None` when no reranker is configured.
pub rerank_time: Option<Duration>,
/// Time from the start of the search call until the timing was assembled.
pub total_time: Duration,
}
impl SearchEngine {
    /// Builds a search engine for the project rooted at `root`.
    ///
    /// Opens the index stored at `<root>/.vyctor/index.duckdb` with the
    /// configured embedding dimensions, then creates the embedding provider
    /// and the reranker from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error from opening the storage or from creating the
    /// embedding provider or reranker.
    pub fn new(root: &Path, config: &VyctorConfig, verbose: bool) -> Result<Self> {
        let db_path = root.join(".vyctor").join("index.duckdb");
        let storage = Storage::new(&db_path, config.embedding.dimensions)?;
        let embedder = create_provider(&config.embedding, verbose)?;
        let reranker = create_reranker(&config.reranker)?;
        let rerank_top_k = config.reranker.top_k;

        Ok(Self {
            storage,
            embedder,
            reranker,
            rerank_top_k,
        })
    }

    /// Searches the index for `query` and returns at most `limit` results
    /// together with per-stage timings.
    ///
    /// The query is embedded, the storage is searched with that embedding
    /// (optionally restricted by `folder_filter`), and, when a reranker is
    /// configured, the hits are reordered by the reranker and their `score`
    /// is replaced by the reranker's relevance score.
    ///
    /// # Errors
    ///
    /// Propagates any error from the embedding provider, the storage search,
    /// or the reranker.
    pub async fn search(
        &self,
        query: &str,
        limit: usize,
        folder_filter: Option<&str>,
    ) -> Result<(Vec<SearchResult>, SearchTiming)> {
        let total_start = Instant::now();

        let embed_start = Instant::now();
        let query_result = self.embedder.embed(query).await?;
        let embed_time = embed_start.elapsed();

        // With a reranker, fetch a wider candidate pool so it has more than
        // just the final `limit` hits to choose from.
        let initial_limit = if self.reranker.is_some() {
            self.rerank_top_k.max(limit)
        } else {
            limit
        };

        let search_start = Instant::now();
        let mut results =
            self.storage
                .search(&query_result.embedding, initial_limit, folder_filter)?;
        let search_time = search_start.elapsed();

        let rerank_time = if let Some(reranker) = &self.reranker {
            let rerank_start = Instant::now();

            // Nothing to reorder when storage returned no hits, so the
            // reranker is not called with an empty document list.
            if !results.is_empty() {
                // Each document id is the hit's position in `results`, which
                // is how reranked entries are mapped back below.
                let documents: Vec<DocumentToRerank> = results
                    .iter()
                    .enumerate()
                    .map(|(i, r)| DocumentToRerank {
                        id: i,
                        content: r.chunk_content.clone(),
                    })
                    .collect();

                let ranked = reranker.rerank(query, documents).await?;

                // Wrap candidates in `Option` so each one can be moved out
                // (rather than cloned) the first time the reranker names it.
                let mut candidates: Vec<Option<SearchResult>> =
                    results.into_iter().map(Some).collect();

                results = ranked
                    .into_iter()
                    .filter_map(|rr| {
                        // An id outside the candidate list, or one that was
                        // already consumed, is skipped instead of panicking.
                        let mut hit = candidates.get_mut(rr.id)?.take()?;
                        hit.score = rr.relevance_score;
                        Some(hit)
                    })
                    .take(limit)
                    .collect();
            }

            Some(rerank_start.elapsed())
        } else {
            results.truncate(limit);
            None
        };

        let timing = SearchTiming {
            embed_time,
            search_time,
            rerank_time,
            total_time: total_start.elapsed(),
        };

        Ok((results, timing))
    }

    /// Returns `true` when a reranker is configured.
    #[allow(dead_code)]
    pub fn has_reranker(&self) -> bool {
        self.reranker.is_some()
    }

    /// Returns the configured reranker's model name, or `None` when there is
    /// no reranker.
    pub fn reranker_model(&self) -> Option<&str> {
        self.reranker.as_ref().map(|r| r.model_name())
    }

    /// Borrows the underlying index storage.
    #[allow(dead_code)]
    pub fn storage(&self) -> &Storage {
        &self.storage
    }
}
// Test-only module; it currently contains no tests.
#[cfg(test)]
mod tests {
}