use tracing::{debug, info};
use crate::chunker::{chunk_document, ChunkConfig};
use crate::embedding::Embedder;
use crate::error::RagError;
use crate::metadata_filter::MetadataFilter;
use crate::vector_store::{SearchResult, VectorStore};
/// Query-time settings for a [`Retriever`].
///
/// Start from [`RetrieverConfig::default`] and adjust individual settings with
/// the chained `with_*` methods.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct RetrieverConfig {
    /// Maximum number of results returned per query.
    pub top_k: usize,
    /// Minimum score threshold used when searching the vector store.
    pub min_score: f32,
    /// Whether lexical-overlap reranking is applied to search results.
    pub rerank: bool,
}

impl Default for RetrieverConfig {
    /// Five results, a `min_score` of `0.0`, and reranking turned off.
    fn default() -> Self {
        RetrieverConfig {
            rerank: false,
            min_score: 0.0,
            top_k: 5,
        }
    }
}

impl RetrieverConfig {
    /// Returns this config with `top_k` replaced; other settings are kept.
    #[must_use]
    pub fn with_top_k(self, top_k: usize) -> Self {
        Self { top_k, ..self }
    }

    /// Returns this config with `min_score` replaced; other settings are kept.
    #[must_use]
    pub fn with_min_score(self, min_score: f32) -> Self {
        Self { min_score, ..self }
    }

    /// Returns this config with the `rerank` flag replaced; other settings are kept.
    #[must_use]
    pub fn with_rerank(self, rerank: bool) -> Self {
        Self { rerank, ..self }
    }
}
/// Document retriever: chunks and embeds documents with `E`, stores the
/// vectors in a [`VectorStore`], and answers queries by vector search.
pub struct Retriever<E>
where
    E: Embedder,
{
    /// Vector index; one entry is inserted per indexed chunk.
    store: VectorStore,
    /// Embedding model used for both document chunks and queries.
    embedder: E,
    /// Query-time settings (result count, score threshold, reranking).
    config: RetrieverConfig,
    /// Number of document ids handed out so far; the next document added
    /// receives this value as its id.
    doc_count: usize,
}
impl<E: Embedder> Retriever<E> {
    /// Creates an empty retriever whose vector store is created with the
    /// embedder's `embedding_dim()`.
    pub fn new(embedder: E, config: RetrieverConfig) -> Self {
        let dim = embedder.embedding_dim();
        Self {
            store: VectorStore::new(dim),
            embedder,
            config,
            doc_count: 0,
        }
    }

    /// Reassembles a retriever from previously built parts.
    ///
    /// No consistency checks are performed. `doc_count` is used as the id of
    /// the next added document, so it should be greater than every document
    /// id already present in `store`; otherwise ids may be reused.
    #[doc(hidden)]
    pub fn from_parts(
        embedder: E,
        store: VectorStore,
        doc_count: usize,
        config: RetrieverConfig,
    ) -> Self {
        Self {
            store,
            embedder,
            config,
            doc_count,
        }
    }

    /// Chunks `text` with `chunk_config`, embeds every chunk and inserts it
    /// into the store under a newly assigned document id.
    ///
    /// Returns the number of chunks indexed for this document.
    ///
    /// All chunks are embedded before anything is stored. An embedding
    /// failure therefore leaves the retriever unchanged: no document id is
    /// consumed and no chunks are inserted.
    ///
    /// # Errors
    /// - [`RagError::EmptyDocument`] if `text` is empty or whitespace-only.
    /// - Any error from the embedder or from the vector store insert. If an
    ///   insert fails part-way, chunks inserted before it remain in the store
    ///   and the document id stays consumed.
    pub fn add_document(
        &mut self,
        text: &str,
        chunk_config: &ChunkConfig,
    ) -> Result<usize, RagError> {
        if text.trim().is_empty() {
            return Err(RagError::EmptyDocument);
        }
        let doc_id = self.doc_count;
        // Embed everything up front so a failing embed cannot leave a
        // half-indexed document behind or burn a document id.
        let embedded = chunk_document(text, doc_id, chunk_config)
            .into_iter()
            .map(|chunk| -> Result<_, RagError> {
                let vector = self.embedder.embed(&chunk.text)?;
                Ok((vector, chunk))
            })
            .collect::<Result<Vec<_>, RagError>>()?;
        // Reserve the id before inserting: if an insert fails part-way, the
        // chunks already stored carry `doc_id`, so it must not be reused.
        self.doc_count += 1;
        let indexed = embedded.len();
        for (vector, chunk) in embedded {
            self.store.insert(vector, chunk)?;
        }
        debug!(doc_id, indexed, "document indexed");
        Ok(indexed)
    }

    /// Indexes each text in order with [`Self::add_document`] and returns the
    /// per-document chunk counts.
    ///
    /// # Errors
    /// Stops at the first failing document and returns its error. Documents
    /// indexed before it stay in the retriever.
    pub fn add_documents(
        &mut self,
        texts: &[&str],
        chunk_config: &ChunkConfig,
    ) -> Result<Vec<usize>, RagError> {
        let mut counts = Vec::with_capacity(texts.len());
        for text in texts {
            counts.push(self.add_document(text, chunk_config)?);
        }
        info!(
            documents = texts.len(),
            total_chunks = counts.iter().sum::<usize>(),
            "batch indexing complete"
        );
        Ok(counts)
    }

    /// Embeds `query` and returns up to `config.top_k` results from the
    /// store's thresholded search (using `config.min_score`). The results are
    /// lexically reranked when `config.rerank` is set.
    ///
    /// # Errors
    /// - [`RagError::EmptyQuery`] if `query` is empty or whitespace-only.
    /// - [`RagError::NoDocumentsIndexed`] if the store holds no chunks.
    /// - Any error from embedding the query.
    pub fn retrieve(&self, query: &str) -> Result<Vec<SearchResult>, RagError> {
        self.ensure_searchable(query)?;
        let query_vec = self.embedder.embed(query)?;
        let results =
            self.store
                .search_with_threshold(&query_vec, self.config.top_k, self.config.min_score);
        let results = self.maybe_rerank(results, query);
        debug!(
            query_len = query.len(),
            hits = results.len(),
            "retrieval complete"
        );
        Ok(results)
    }

    /// Like [`Self::retrieve`], but restricts the search to chunks accepted
    /// by `filter`.
    ///
    /// `config.min_score` is applied here as well: results scoring below it
    /// are dropped before reranking, so filtered and unfiltered retrieval
    /// honour the same threshold.
    ///
    /// # Errors
    /// - [`RagError::EmptyQuery`] if `query` is empty or whitespace-only.
    /// - [`RagError::NoDocumentsIndexed`] if the store holds no chunks.
    /// - Any error from embedding the query or from the filtered search.
    pub fn retrieve_filtered(
        &self,
        query: &str,
        filter: &MetadataFilter,
    ) -> Result<Vec<SearchResult>, RagError> {
        self.ensure_searchable(query)?;
        let query_vec = self.embedder.embed(query)?;
        let mut results = self
            .store
            .search_filtered(&query_vec, self.config.top_k, filter)?;
        // `search_filtered` takes no threshold, so enforce `min_score` here to
        // stay consistent with `retrieve`.
        let min_score = self.config.min_score;
        results.retain(|r| r.score >= min_score);
        let results = self.maybe_rerank(results, query);
        debug!(
            query_len = query.len(),
            hits = results.len(),
            "filtered retrieval complete"
        );
        Ok(results)
    }

    /// Convenience wrapper around [`Self::retrieve`] that returns only the
    /// text of each retrieved chunk, in result order.
    ///
    /// # Errors
    /// Same as [`Self::retrieve`].
    pub fn retrieve_text(&self, query: &str) -> Result<Vec<String>, RagError> {
        Ok(self
            .retrieve(query)?
            .into_iter()
            .map(|r| r.chunk.text)
            .collect())
    }

    /// Number of document ids assigned so far (including any starting value
    /// supplied through `from_parts`).
    pub fn document_count(&self) -> usize {
        self.doc_count
    }

    /// Number of entries currently held by the vector store.
    pub fn chunk_count(&self) -> usize {
        self.store.len()
    }

    /// Read-only access to the underlying vector store.
    pub fn store(&self) -> &VectorStore {
        &self.store
    }

    /// Read-only access to the embedder.
    pub fn embedder(&self) -> &E {
        &self.embedder
    }

    /// Rejects blank queries and queries against an empty store.
    fn ensure_searchable(&self, query: &str) -> Result<(), RagError> {
        if query.trim().is_empty() {
            return Err(RagError::EmptyQuery);
        }
        if self.store.is_empty() {
            return Err(RagError::NoDocumentsIndexed);
        }
        Ok(())
    }

    /// Applies lexical reranking when `config.rerank` is enabled; otherwise
    /// returns `results` unchanged.
    fn maybe_rerank(&self, results: Vec<SearchResult>, query: &str) -> Vec<SearchResult> {
        if self.config.rerank {
            rerank(results, query)
        } else {
            results
        }
    }
}
/// Boosts each result's score by lexical overlap with `query`, then sorts the
/// results by descending score.
///
/// Overlap is counted as the number of distinct tokens, as produced by
/// `crate::embedding::tokenize`, that appear in both the query and the chunk
/// text. Each shared token adds `0.02`, the total boost is capped at `0.1`,
/// and the boosted score is capped at `1.0`.
///
/// The sort is stable, so results whose boosted scores tie keep the order the
/// vector store returned them in. NaN scores are placed after all other
/// scores. This keeps the comparator a total order, which the std sort
/// functions require (they may panic or produce an unspecified order
/// otherwise).
fn rerank(mut results: Vec<SearchResult>, query: &str) -> Vec<SearchResult> {
    let query_tokens: std::collections::HashSet<String> =
        crate::embedding::tokenize(query).into_iter().collect();
    for result in results.iter_mut() {
        let chunk_tokens: std::collections::HashSet<String> =
            crate::embedding::tokenize(&result.chunk.text)
                .into_iter()
                .collect();
        let overlap = query_tokens.intersection(&chunk_tokens).count();
        if overlap > 0 {
            let boost = (overlap as f32 * 0.02).min(0.1);
            result.score = (result.score + boost).min(1.0);
        }
    }
    results.sort_by(|a, b| {
        // Descending by score; when the comparison is undefined (a NaN is
        // involved), order NaN after every real score.
        b.score
            .partial_cmp(&a.score)
            .unwrap_or_else(|| a.score.is_nan().cmp(&b.score.is_nan()))
    });
    results
}