pub mod agent;
pub mod beir;
pub mod synth;
mod vector;
use std::collections::{HashMap, HashSet};
use rig_retrieval_evals::dataset::{Qrels, RetrievedDoc};
use rig_retrieval_evals::report::MultiReport;
use rig_retrieval_evals::retrieval::{
HitRateAtK, MapAtK, Mrr, NdcgAtK, PrecisionAtK, RecallAtK, RetrievalMetric,
};
use rig_retrieval_evals::retriever::RetrieveFuture;
use rig_retrieval_evals::score_retriever;
use crate::engine::{self, SearchQuery};
use crate::error::DciError;
use crate::sandbox::CorpusRoot;
pub use rig_retrieval_evals::Retriever;
pub use vector::VectorRetriever;
/// Knobs for a single evaluation run performed by [`evaluate`].
#[derive(Debug, Clone)]
pub struct EvalConfig {
    /// Rank cutoff handed to every `@k` metric and forwarded to `score_retriever`.
    pub k: usize,
    /// Number of bootstrap resampling iterations; `0` skips the bootstrap step.
    pub bootstrap_iters: usize,
    /// Confidence level forwarded to the bootstrap step (e.g. `0.95`).
    pub ci_level: f64,
    /// Seed forwarded to the bootstrap step.
    pub seed: u64,
    /// Dataset label attached to the resulting report.
    pub dataset_id: String,
}

impl Default for EvalConfig {
    /// Defaults: `k = 10`, 1000 bootstrap iterations, 0.95 confidence level,
    /// seed 42 and the dataset label `"dci-eval"`.
    fn default() -> Self {
        let dataset_id = String::from("dci-eval");
        EvalConfig {
            k: 10,
            bootstrap_iters: 1000,
            ci_level: 0.95,
            seed: 42,
            dataset_id,
        }
    }
}
impl EvalConfig {
    /// Builds the metric suite for a run: recall, precision, hit rate, MRR,
    /// MAP and nDCG, with every `@k` metric using this config's cutoff.
    fn metrics(&self) -> Vec<Box<dyn RetrievalMetric>> {
        let cutoff = self.k;
        let recall = RecallAtK::new(cutoff);
        let precision = PrecisionAtK::new(cutoff);
        let hit_rate = HitRateAtK::new(cutoff);
        let map = MapAtK::new(cutoff);
        let ndcg = NdcgAtK::new(cutoff);
        // Order is kept stable: it is the order the metrics reach `score_retriever`.
        vec![
            Box::new(recall),
            Box::new(precision),
            Box::new(hit_rate),
            Box::new(Mrr),
            Box::new(map),
            Box::new(ndcg),
        ]
    }
}
/// Scores `retriever` against `qrels` with the metric suite derived from `cfg`.
///
/// When `cfg.bootstrap_iters` is non-zero the report is extended with
/// bootstrap statistics using `cfg.ci_level` and `cfg.seed`. The report is
/// then tagged with `cfg.dataset_id` and the retriever's name.
///
/// # Errors
///
/// Any failure reported by `score_retriever` is stringified and returned as
/// [`DciError::Worker`].
pub async fn evaluate(
    retriever: &dyn Retriever,
    qrels: &Qrels,
    cfg: &EvalConfig,
) -> Result<MultiReport, DciError> {
    let metrics = cfg.metrics();
    // NOTE(review): the trailing `1` is presumably the concurrency level — confirm
    // against `score_retriever`'s signature.
    let mut report = match score_retriever(retriever, qrels, cfg.k, &metrics, 1).await {
        Ok(scored) => scored,
        Err(err) => return Err(DciError::Worker(err.to_string())),
    };
    if cfg.bootstrap_iters != 0 {
        report = report.with_bootstrap(cfg.bootstrap_iters, cfg.ci_level, cfg.seed);
    }
    let tagged = report
        .with_dataset(cfg.dataset_id.clone())
        .with_store(retriever.name().to_string());
    Ok(tagged)
}
/// Paired evaluation reports for the DCI retriever and a baseline retriever.
pub struct Comparison {
    /// Report produced for the DCI retriever.
    pub dci: MultiReport,
    /// Report produced for the baseline retriever.
    pub baseline: MultiReport,
}

impl Comparison {
    /// Evaluates `dci` and then `baseline` on the same `qrels` and `cfg`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`evaluate`]; the baseline is not
    /// evaluated if the DCI run fails.
    pub async fn run(
        dci: &dyn Retriever,
        baseline: &dyn Retriever,
        qrels: &Qrels,
        cfg: &EvalConfig,
    ) -> Result<Self, DciError> {
        let dci_report = evaluate(dci, qrels, cfg).await?;
        let baseline_report = evaluate(baseline, qrels, cfg).await?;
        Ok(Comparison {
            dci: dci_report,
            baseline: baseline_report,
        })
    }

    /// Renders both reports plus a delta section as one Markdown document.
    ///
    /// The delta section is left empty when `delta_markdown` yields no value.
    pub fn to_markdown(&self) -> String {
        let dci_section = self.dci.to_markdown();
        let baseline_section = self.baseline.to_markdown();
        let delta_section = self.dci.delta_markdown(&self.baseline).unwrap_or_default();
        format!(
            "## DCI\n\n{}\n\n## Baseline\n\n{}\n\n## Delta (DCI − baseline)\n\n{}",
            dci_section, baseline_section, delta_section
        )
    }
}
/// Lexical retriever that answers queries by searching a corpus through
/// `engine::search`.
pub struct DciRetriever {
    /// Corpus the searches run against.
    corpus: CorpusRoot,
    /// Label reported through `Retriever::name`.
    name: String,
}

impl DciRetriever {
    /// Creates a retriever over `corpus` with the default name `"dci-lexical"`.
    pub fn new(corpus: CorpusRoot) -> Self {
        let name = String::from("dci-lexical");
        DciRetriever { corpus, name }
    }

    /// Replaces the retriever's name, builder-style.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        let renamed: String = name.into();
        self.name = renamed;
        self
    }
}
impl Retriever for DciRetriever {
fn name(&self) -> &str {
&self.name
}
fn retrieve<'a>(&'a self, query: &'a str, k: usize) -> RetrieveFuture<'a> {
let terms = tokenize(query);
let corpus = self.corpus.clone();
Box::pin(async move {
if terms.is_empty() {
return Ok(Vec::new());
}
let max_results = corpus.limits().max_results;
let joined = tokio::task::spawn_blocking(move || {
let pattern = terms
.iter()
.map(|t| regex_escape(t))
.collect::<Vec<_>>()
.join("|");
let result = engine::search(
&corpus,
&SearchQuery {
pattern,
path_glob: None,
case_insensitive: true,
context_lines: 0,
max_results: Some(max_results),
},
)?;
let lc_terms: Vec<String> = terms.iter().map(|t| t.to_lowercase()).collect();
let mut per_file: HashMap<String, (HashSet<usize>, usize)> = HashMap::new();
for hit in &result.hits {
let entry = per_file.entry(hit.path.clone()).or_default();
entry.1 += 1;
let line_lc = hit.text.to_lowercase();
for (i, term) in lc_terms.iter().enumerate() {
if line_lc.contains(term) {
entry.0.insert(i);
}
}
}
let mut ranked: Vec<RetrievedDoc> = per_file
.into_iter()
.map(|(doc_id, (distinct, total))| {
let score = distinct.len() as f64 + 0.1 * (1.0 + total as f64).ln();
RetrievedDoc { doc_id, score }
})
.collect();
ranked.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.doc_id.cmp(&b.doc_id))
});
ranked.truncate(k);
Ok::<_, DciError>(ranked)
})
.await
.map_err(|e| rig_retrieval_evals::Error::Config(e.to_string()))?;
joined.map_err(|e| rig_retrieval_evals::Error::Config(e.to_string()))
})
}
}
/// Lowercase function words that `tokenize` drops from queries before they
/// are turned into search terms.
const STOPWORDS: &[&str] = &[
    "the", "a", "an", "of", "to", "in", "is", "are",
    "was", "were", "and", "or", "for", "on", "at", "by",
    "with", "as", "that", "this", "it", "be", "from", "who",
    "what", "when", "where", "which", "how", "did", "do", "does",
];
/// Splits `query` into lowercase search terms.
///
/// Each whitespace-separated token has leading and trailing non-alphanumeric
/// characters stripped (internal punctuation is kept) and is lowercased.
/// Tokens shorter than two bytes, stopwords and repeats are dropped; the
/// surviving terms keep their first-occurrence order.
fn tokenize(query: &str) -> Vec<String> {
    let mut seen: HashSet<String> = HashSet::new();
    let mut terms = Vec::new();
    for raw in query.split_whitespace() {
        let word = raw
            .trim_matches(|c: char| !c.is_alphanumeric())
            .to_lowercase();
        // `len()` counts UTF-8 bytes, not characters.
        if word.len() < 2 || STOPWORDS.contains(&word.as_str()) {
            continue;
        }
        if seen.insert(word.clone()) {
            terms.push(word);
        }
    }
    terms
}
/// Escapes `term` so that it matches literally when embedded in a regex.
///
/// Only genuine regex metacharacters are backslash-escaped (the same set the
/// `regex` crate's `escape` uses). Escaping every non-alphanumeric character
/// is not safe in Rust `regex` syntax: `\<` and `\>` are word-boundary
/// assertions rather than literals, and a backslash before a non-ASCII
/// character (for example an em dash inside a token) is rejected as an
/// unrecognized escape, which would fail the whole search.
///
/// NOTE(review): assumes `engine::search` compiles patterns with Rust `regex`
/// syntax — confirm against the engine.
fn regex_escape(term: &str) -> String {
    let mut out = String::with_capacity(term.len() * 2);
    for ch in term.chars() {
        let is_meta = matches!(
            ch,
            '\\' | '.'
                | '+'
                | '*'
                | '?'
                | '('
                | ')'
                | '|'
                | '['
                | ']'
                | '{'
                | '}'
                | '^'
                | '$'
                | '#'
                | '&'
                | '-'
                | '~'
        );
        if is_meta {
            out.push('\\');
        }
        out.push(ch);
    }
    out
}
#[cfg(test)]
mod tests {
    // Tests are allowed to panic on failure, so the panic-related clippy
    // restrictions are relaxed for this module only.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;

    /// Stopwords vanish and surrounding punctuation is stripped.
    #[test]
    fn tokenize_drops_stopwords_and_punctuation() {
        let expected = ["wrote", "1984"];
        assert_eq!(tokenize("Who wrote 1984?"), expected);
    }

    /// Punctuation inside a token (here the dots of an IP address) survives.
    #[test]
    fn tokenize_preserves_internal_punctuation() {
        let terms = tokenize("contact from 10.0.0.5 please");
        assert!(terms.iter().any(|term| term == "10.0.0.5"));
    }

    /// Regex metacharacters are escaped; plain text passes through untouched.
    #[test]
    fn regex_escape_neutralizes_metachars() {
        let cases = [
            ("10.0.0.5", "10\\.0\\.0\\.5"),
            ("a+b", "a\\+b"),
            ("abc", "abc"),
        ];
        for (raw, escaped) in cases {
            assert_eq!(regex_escape(raw), escaped);
        }
    }
}