use rig_core::OneOrMany;
use rig_core::embeddings::EmbeddingModel;
use rig_core::vector_store::VectorSearchRequest;
use rig_core::vector_store::VectorStoreIndex;
use rig_core::vector_store::in_memory_store::{InMemoryVectorIndex, InMemoryVectorStore};
use rig_retrieval_evals::dataset::RetrievedDoc;
use rig_retrieval_evals::retriever::RetrieveFuture;
use std::collections::HashMap;
use super::Retriever;
use crate::engine;
use crate::error::{DciError, Result};
use crate::sandbox::CorpusRoot;
/// Maximum length of one chunk, counted in `char`s (not bytes); passed to
/// `chunk_text` as its `size` argument when building the index.
const CHUNK_CHARS: usize = 800;
/// Number of `char`s that consecutive chunks of the same document share;
/// passed to `chunk_text` as its `overlap` argument.
const CHUNK_OVERLAP: usize = 120;
/// Multiplier applied to the requested `k` when querying the index: several
/// chunk hits can belong to the same document, so more than `k` chunks are
/// fetched before hits are collapsed to one entry per document.
const RETRIEVE_OVERSAMPLE: usize = 8;
/// Retriever that answers queries from an in-memory vector index built over
/// chunks of the corpus documents.
pub struct VectorRetriever<M>
where
    M: EmbeddingModel,
{
    /// Searchable index; each entry's stored document is the path of the file
    /// the chunk came from.
    index: InMemoryVectorIndex<M, String>,
    /// Label reported through `Retriever::name`.
    name: String,
}
impl<M: EmbeddingModel + Sync> VectorRetriever<M> {
    /// Builds a retriever by chunking and embedding every file in `corpus`.
    ///
    /// Each file listed by `engine::list_files` is read, split with
    /// `chunk_text`, and every chunk is embedded with `model`. Chunks are
    /// stored under the id `"{path}#{idx}"` with the file path as the stored
    /// document, so a search hit can be mapped back to its source file.
    ///
    /// # Errors
    ///
    /// Propagates failures from listing or reading corpus files, and returns
    /// `DciError::Worker` when embedding a chunk fails.
    pub async fn build(corpus: &CorpusRoot, model: M) -> Result<Self> {
        let files = engine::list_files(corpus)?;
        let mut entries = Vec::new();

        for path in files {
            let body = engine::read_document(corpus, &path)?;
            let pieces = chunk_text(&body, CHUNK_CHARS, CHUNK_OVERLAP);

            for (idx, piece) in pieces.into_iter().enumerate() {
                // Chunks are embedded one at a time so a failure can name the
                // exact chunk that caused it.
                let vector = model.embed_text(&piece).await.map_err(|e| {
                    DciError::Worker(format!("embedding failed for {path}#{idx}: {e}"))
                })?;
                entries.push((
                    format!("{path}#{idx}"),
                    path.clone(),
                    OneOrMany::one(vector),
                ));
            }
        }

        let store = InMemoryVectorStore::from_documents_with_ids(entries);
        Ok(Self {
            index: store.index(model),
            name: String::from("vector-baseline"),
        })
    }

    /// Returns the retriever with its reported name replaced by `name`.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..self
        }
    }
}
impl<M: EmbeddingModel + Sync + 'static> Retriever for VectorRetriever<M> {
    /// Returns the label this retriever was given (see `with_name`).
    fn name(&self) -> &str {
        &self.name
    }

    /// Returns up to `k` documents ranked by their best-matching chunk.
    ///
    /// The index is asked for `k * RETRIEVE_OVERSAMPLE` chunk hits, the hits
    /// are collapsed to one entry per source path keeping the highest chunk
    /// score, and the entries are sorted by descending score with `doc_id` as
    /// the tie-breaker before being cut to `k`.
    ///
    /// # Errors
    ///
    /// A failed vector search is reported as
    /// `rig_retrieval_evals::Error::Config`.
    fn retrieve<'a>(&'a self, query: &'a str, k: usize) -> RetrieveFuture<'a> {
        Box::pin(async move {
            // Nothing can survive `truncate(0)`, so skip the query embedding
            // and the search altogether.
            if k == 0 {
                return Ok(Vec::new());
            }

            // Saturate instead of overflowing: a plain `k * RETRIEVE_OVERSAMPLE`
            // panics in debug builds and wraps in release builds for very large
            // `k`, which would silently defeat the `.max(k)` guard.
            let samples = k.saturating_mul(RETRIEVE_OVERSAMPLE).max(k) as u64;
            let request = VectorSearchRequest::builder()
                .query(query.to_string())
                .samples(samples)
                .build();
            let hits = self.index.top_n::<String>(request).await.map_err(|e| {
                rig_retrieval_evals::Error::Config(format!("vector search failed: {e}"))
            })?;

            // Several chunks of one file may be hit; keep only the best score
            // seen for each source path.
            let mut best: HashMap<String, f64> = HashMap::new();
            for (score, _chunk_id, path) in hits {
                best.entry(path)
                    .and_modify(|s| {
                        if score > *s {
                            *s = score;
                        }
                    })
                    .or_insert(score);
            }

            let mut ranked: Vec<RetrievedDoc> = best
                .into_iter()
                .map(|(doc_id, score)| RetrievedDoc { doc_id, score })
                .collect();
            // Highest score first. `partial_cmp` yields `None` when a score is
            // NaN; such pairs are treated as equal so `doc_id` decides. The id
            // tie-breaker also makes the order independent of `HashMap`
            // iteration order.
            ranked.sort_by(|a, b| {
                b.score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.doc_id.cmp(&b.doc_id))
            });
            ranked.truncate(k);
            Ok(ranked)
        })
    }
}
/// Splits `text` into windows of at most `size` characters, where each window
/// starts `size - overlap` characters after the previous one.
///
/// Lengths are counted in `char`s, so multi-byte characters are never split.
/// The text is returned whole, as a single chunk, when it already fits in
/// `size` characters or when `size` is zero. If `overlap` is at least `size`
/// the window advances by one character at a time so the loop still
/// terminates. The final window always ends at the end of the text.
fn chunk_text(text: &str, size: usize, overlap: usize) -> Vec<String> {
    let symbols: Vec<char> = text.chars().collect();
    let total = symbols.len();
    if size == 0 || total <= size {
        return vec![text.to_owned()];
    }

    // Distance between consecutive window starts; never zero, so `step_by`
    // below cannot panic and the walk always makes progress.
    let stride = size.saturating_sub(overlap).max(1);

    let mut windows = Vec::new();
    for begin in (0..total).step_by(stride) {
        let finish = total.min(begin + size);
        if let Some(window) = symbols.get(begin..finish) {
            windows.push(window.iter().collect::<String>());
        }
        // The window that reaches the end of the text is the last one; later
        // starts would only repeat part of its content.
        if finish == total {
            break;
        }
    }
    windows
}
#[cfg(test)]
mod tests {
    // Test code may unwrap, index and panic freely; these lints are relaxed
    // for this module only.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;

    #[test]
    fn short_text_is_a_single_chunk() {
        let chunks = chunk_text("hello world", 800, 120);
        assert_eq!(chunks, vec!["hello world".to_string()]);
    }

    #[test]
    fn zero_size_returns_the_whole_text() {
        let chunks = chunk_text("abc", 0, 0);
        assert_eq!(chunks, vec!["abc".to_string()]);
    }

    #[test]
    fn long_text_is_windowed_with_overlap() {
        // Cycle through the alphabet so neighbouring windows are
        // distinguishable; an all-identical text cannot reveal a missing or
        // misplaced overlap.
        let text: String = ('a'..='z').cycle().take(2000).collect();
        let chunks = chunk_text(&text, 800, 120);

        // Window starts are 0, 680 and 1360; the third window reaches the end.
        assert_eq!(chunks.len(), 3, "2000 chars should split into 3 chunks");
        assert!(chunks.iter().all(|c| c.chars().count() <= 800));

        // Each window must begin with the last 120 chars of a full-size
        // predecessor.
        for pair in chunks.windows(2) {
            let tail: String = pair[0].chars().skip(800 - 120).collect();
            assert_eq!(tail.chars().count(), 120);
            assert!(
                pair[1].starts_with(&tail),
                "consecutive chunks should share 120 chars"
            );
        }

        // No text may be lost at either end.
        assert!(text.starts_with(chunks[0].as_str()));
        assert!(text.ends_with(chunks[2].as_str()));
    }

    #[test]
    fn windows_are_measured_in_chars_not_bytes() {
        let chunks = chunk_text("héllo wörld", 5, 2);
        assert_eq!(chunks, vec!["héllo", "lo wö", "wörld"]);
    }
}