dci-tool 0.1.0

Direct Corpus Interaction: a sandboxed, ripgrep-backed corpus-search toolset and agent for cyber-focused LLM agents, built on rig.
Documentation
//! Vector-embedding baseline retriever for head-to-head comparison.
//!
//! Builds an in-memory vector index over the same corpus and answers the same
//! queries, so its retrieval quality can be scored with the identical metrics
//! as the DCI retriever. It is generic over any rig [`EmbeddingModel`], so the
//! caller brings their own embeddings (a provider model in production, or a
//! deterministic offline model in tests).

use rig_core::OneOrMany;
use rig_core::embeddings::EmbeddingModel;
use rig_core::vector_store::VectorSearchRequest;
use rig_core::vector_store::VectorStoreIndex;
use rig_core::vector_store::in_memory_store::{InMemoryVectorIndex, InMemoryVectorStore};
use rig_retrieval_evals::dataset::RetrievedDoc;
use rig_retrieval_evals::retriever::RetrieveFuture;
use std::collections::HashMap;

use super::Retriever;
use crate::engine;
use crate::error::{DciError, Result};
use crate::sandbox::CorpusRoot;

/// Maximum length, in `char`s, of each window handed to the embedding model.
const CHUNK_CHARS: usize = 800;
/// Number of `char`s shared by neighbouring windows, so that evidence sitting
/// near a window boundary is less likely to be split between two chunks.
const CHUNK_OVERLAP: usize = 120;
/// Multiplier applied to `k` when querying the index. Chunk hits are
/// over-fetched by this factor so that, once chunks are collapsed to their
/// parent files, enough distinct files are likely to remain to fill the top-`k`.
const RETRIEVE_OVERSAMPLE: usize = 8;

/// A [`Retriever`] backed by an in-memory vector store of embedded corpus files.
///
/// Files are indexed at chunk granularity but reported at file granularity:
/// each returned document id is the file's corpus-relative path, matching the
/// ids the [`DciRetriever`](super::DciRetriever) produces, so both can be
/// scored against the same qrels.
pub struct VectorRetriever<M: EmbeddingModel> {
    /// Vector index over chunk embeddings; the stored document for every chunk
    /// is its parent file path.
    index: InMemoryVectorIndex<M, String>,
    /// Label reported by [`Retriever::name`]; `"vector-baseline"` unless
    /// overridden via `with_name`.
    name: String,
}

impl<M: EmbeddingModel + Sync> VectorRetriever<M> {
    /// Build the index by reading, chunking, and embedding every file in
    /// `corpus`.
    ///
    /// Each file is split into overlapping character windows
    /// ([`CHUNK_CHARS`] with [`CHUNK_OVERLAP`]) and each chunk is embedded
    /// separately, so long documents are represented fairly rather than crushed
    /// into a single averaged vector. Every chunk stores its parent file path as
    /// its document, and chunk hits are collapsed back to file granularity at
    /// query time to match the document-level qrels.
    ///
    /// Empty and whitespace-only chunks are skipped: they carry no retrievable
    /// evidence, and some embedding providers reject empty input outright. A
    /// file consisting only of whitespace therefore contributes no vectors and
    /// can never be retrieved.
    ///
    /// # Errors
    ///
    /// Returns an error if the corpus cannot be listed, if any file cannot be
    /// read, or if the embedding model fails on any chunk (reported as
    /// [`DciError::Worker`] naming the offending `path#chunk` id).
    pub async fn build(corpus: &CorpusRoot, model: M) -> Result<Self> {
        let paths = engine::list_files(corpus)?;
        let mut documents = Vec::new();
        for path in paths {
            let text = engine::read_document(corpus, &path)?;
            for (idx, chunk) in chunk_text(&text, CHUNK_CHARS, CHUNK_OVERLAP)
                .into_iter()
                .enumerate()
            {
                // An empty file yields a single "" chunk; embedding it would
                // send empty input to the provider and could abort the build.
                if chunk.trim().is_empty() {
                    continue;
                }
                // Unique chunk id; the stored document is the parent file path.
                let chunk_id = format!("{path}#{idx}");
                let embedding = model.embed_text(&chunk).await.map_err(|e| {
                    DciError::Worker(format!("embedding failed for {chunk_id}: {e}"))
                })?;
                documents.push((chunk_id, path.clone(), OneOrMany::one(embedding)));
            }
        }

        let store = InMemoryVectorStore::from_documents_with_ids(documents);
        Ok(Self {
            index: store.index(model),
            name: "vector-baseline".to_string(),
        })
    }

    /// Override the report store label.
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }
}

impl<M: EmbeddingModel + Sync + 'static> Retriever for VectorRetriever<M> {
    fn name(&self) -> &str {
        &self.name
    }

    /// Return up to `k` files ranked by their best-matching chunk.
    ///
    /// Chunk hits are over-fetched by [`RETRIEVE_OVERSAMPLE`], collapsed to
    /// one entry per parent file (keeping that file's highest chunk score),
    /// sorted by descending score with ties broken by ascending path, and
    /// truncated to `k`. Fewer than `k` results come back when the oversampled
    /// hits cover fewer distinct files.
    ///
    /// NaN scores (e.g. a similarity against a zero vector) rank below every
    /// real score, so they can neither win the ranking nor make the sort
    /// comparator inconsistent.
    fn retrieve<'a>(&'a self, query: &'a str, k: usize) -> RetrieveFuture<'a> {
        Box::pin(async move {
            // Nothing requested: skip the (possibly remote) query embedding.
            if k == 0 {
                return Ok(Vec::new());
            }
            // `max(k)` keeps at least `k` samples even if the oversample factor
            // is ever set to 0; `saturating_mul` avoids overflow for huge `k`.
            let samples = k.saturating_mul(RETRIEVE_OVERSAMPLE).max(k);
            let samples = u64::try_from(samples).unwrap_or(u64::MAX);
            let request = VectorSearchRequest::builder()
                .query(query.to_string())
                .samples(samples)
                .build();
            // `top_n::<String>` returns (score, chunk_id, parent_file_path).
            let hits = self.index.top_n::<String>(request).await.map_err(|e| {
                rig_retrieval_evals::Error::Config(format!("vector search failed: {e}"))
            })?;

            let mut best: HashMap<String, f64> = HashMap::new();
            for (score, _chunk_id, path) in hits {
                // `f64::max` ignores a NaN operand, so a NaN hit seen first can
                // never mask a real score from a later chunk of the same file.
                best.entry(path)
                    .and_modify(|s| *s = s.max(score))
                    .or_insert(score);
            }

            let mut ranked: Vec<RetrievedDoc> = best
                .into_iter()
                .map(|(doc_id, score)| RetrievedDoc { doc_id, score })
                .collect();
            // Map NaN to -inf and compare with `total_cmp`: the comparator is
            // then a total order (which `sort_by` requires) and degenerate
            // scores sink to the bottom.
            let rank_key = |s: f64| if s.is_nan() { f64::NEG_INFINITY } else { s };
            ranked.sort_by(|a, b| {
                rank_key(b.score)
                    .total_cmp(&rank_key(a.score))
                    .then_with(|| a.doc_id.cmp(&b.doc_id))
            });
            ranked.truncate(k);
            Ok(ranked)
        })
    }
}

/// Break `text` into windows of at most `size` characters, each window after
/// the first starting `size - overlap` characters after its predecessor
/// (always advancing by at least one character).
///
/// Text of `size` characters or fewer, and any text when `size` is 0, is
/// returned unchanged as a single chunk (so empty input yields `[""]`). The
/// last window always ends at the end of `text`. Windows are cut on `char`
/// boundaries, so multi-byte UTF-8 sequences are never split.
fn chunk_text(text: &str, size: usize, overlap: usize) -> Vec<String> {
    let glyphs: Vec<char> = text.chars().collect();
    let total = glyphs.len();
    if size == 0 || total <= size {
        return vec![text.to_owned()];
    }

    let stride = size.saturating_sub(overlap).max(1);
    let mut pieces = Vec::new();
    for begin in (0..total).step_by(stride) {
        let finish = total.min(begin + size);
        pieces.extend(
            glyphs
                .get(begin..finish)
                .map(|window| window.iter().collect::<String>()),
        );
        // Stop once a window reaches the end; later starts would only
        // produce suffixes already covered.
        if finish == total {
            break;
        }
    }
    pieces
}

#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;

    #[test]
    fn short_text_is_a_single_chunk() {
        let chunks = chunk_text("hello world", 800, 120);
        assert_eq!(chunks, vec!["hello world".to_string()]);
    }

    #[test]
    fn empty_text_is_one_empty_chunk() {
        assert_eq!(chunk_text("", 800, 120), vec![String::new()]);
    }

    #[test]
    fn long_text_is_windowed_with_overlap() {
        // Varied content, so the overlap checks compare distinct characters
        // rather than a run of identical ones that would match at any offset.
        let text: String = ('a'..='z').cycle().take(2000).collect();
        let chunks = chunk_text(&text, 800, 120);
        assert!(chunks.len() >= 2, "2000 chars should split into >=2 chunks");
        // Every chunk respects the size bound.
        assert!(chunks.iter().all(|c| c.chars().count() <= 800));
        // Consecutive chunks share exactly `overlap` characters.
        for pair in chunks.windows(2) {
            let prev: Vec<char> = pair[0].chars().collect();
            let next: Vec<char> = pair[1].chars().collect();
            assert_eq!(&prev[prev.len() - 120..], &next[..120]);
        }
        // Dropping each chunk's leading overlap reconstructs the text exactly,
        // proving nothing is lost or duplicated between windows.
        let mut rebuilt = chunks[0].clone();
        for chunk in &chunks[1..] {
            rebuilt.extend(chunk.chars().skip(120));
        }
        assert_eq!(rebuilt, text);
    }

    #[test]
    fn multibyte_text_is_split_on_char_boundaries() {
        let text = "äöüß".repeat(3);
        let chunks = chunk_text(&text, 5, 1);
        assert_eq!(chunks, vec!["äöüßä", "äöüßä", "äöüß"]);
    }

    #[test]
    fn zero_size_returns_whole_text() {
        assert_eq!(chunk_text("abc", 0, 0), vec!["abc".to_string()]);
    }

    #[test]
    fn overlap_not_smaller_than_size_still_advances() {
        assert_eq!(chunk_text("abcde", 2, 5), vec!["ab", "bc", "cd", "de"]);
    }
}