#![cfg(feature = "eval")]
#![allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic
)]
use std::path::PathBuf;
use dci_tool::CorpusRoot;
use dci_tool::eval::synth::{self, SyntheticLogConfig};
use dci_tool::eval::{Comparison, DciRetriever, EvalConfig, Retriever, VectorRetriever};
use rig_core::embeddings::{Embedding, EmbeddingError, EmbeddingModel};
use rig_retrieval_evals::dataset::{GoldQuery, Qrels};
/// Deterministic, dependency-free embedding model used as the vector
/// baseline in these tests.
///
/// A text is embedded by feature hashing:
/// - it is split into runs of alphanumeric characters;
/// - every token of at least two characters is lowercased and hashed with
///   64-bit FNV-1a into one of `HashingEmbeddingModel::DIMS` buckets;
/// - the resulting count vector is L2-normalised.
#[derive(Clone, Debug, Default)]
struct HashingEmbeddingModel;

impl HashingEmbeddingModel {
    /// Number of hash buckets, i.e. the embedding dimensionality.
    const DIMS: usize = 64;

    /// 64-bit FNV-1a offset basis (decimal 14695981039346656037).
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;

    /// 64-bit FNV prime (decimal 1099511628211).
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    /// Embeds a single text into a vector of `DIMS` components.
    ///
    /// The result has unit L2 norm, unless the text contains no token of at
    /// least two characters. In that case it is the all-zero vector, because
    /// normalisation is skipped to avoid dividing by zero.
    fn embed_one(text: &str) -> Vec<f64> {
        let mut v = vec![0.0f64; Self::DIMS];
        for token in text
            .split(|c: char| !c.is_alphanumeric())
            // Count characters, not bytes: `str::len` would let a single
            // multi-byte character (e.g. "é") pass the length filter.
            .filter(|t| t.chars().count() >= 2)
        {
            let h = Self::fnv1a(token.to_lowercase().as_bytes());
            // Reduce in u64 before narrowing so no hash bits are discarded
            // by the cast on 32-bit targets.
            let idx = (h % Self::DIMS as u64) as usize;
            v[idx] += 1.0;
        }
        let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        v
    }

    /// 64-bit FNV-1a hash of `bytes`: XOR each byte in, then multiply by the
    /// FNV prime (wrapping).
    fn fnv1a(bytes: &[u8]) -> u64 {
        bytes.iter().fold(Self::FNV_OFFSET_BASIS, |h, &b| {
            (h ^ u64::from(b)).wrapping_mul(Self::FNV_PRIME)
        })
    }
}
impl EmbeddingModel for HashingEmbeddingModel {
const MAX_DOCUMENTS: usize = 1024;
type Client = ();
fn make(_client: &Self::Client, _model: impl Into<String>, _dims: Option<usize>) -> Self {
Self
}
fn ndims(&self) -> usize {
Self::DIMS
}
async fn embed_texts(
&self,
texts: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<Embedding>, EmbeddingError> {
Ok(texts
.into_iter()
.map(|document| {
let vec = Self::embed_one(&document);
Embedding { document, vec }
})
.collect())
}
}
/// Opens the fixture corpus located at `<CARGO_MANIFEST_DIR>/tests/fixtures`.
///
/// Panics (failing the calling test) if `CorpusRoot::new` rejects the
/// directory.
fn fixtures() -> CorpusRoot {
    let mut corpus_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    corpus_dir.push("tests/fixtures");
    CorpusRoot::new(corpus_dir).expect("fixture corpus")
}
/// A multi-term query should rank documents by how many distinct query terms
/// they contain. Per the assertion below, `auth.log` matches both terms and
/// `notes.md` only one, so `auth.log` must appear first.
#[tokio::test]
async fn multi_term_query_ranks_by_distinct_terms_in_one_walk() {
    let retriever = DciRetriever::new(fixtures());
    let hits = retriever
        .retrieve("alice shadow", 10)
        .await
        .expect("retrieve");

    // Rank (0-based position) of the first hit whose id ends with `suffix`.
    let rank_of = |suffix: &str| hits.iter().position(|hit| hit.doc_id.ends_with(suffix));

    let auth_rank = rank_of("auth.log").expect("auth.log present");
    let notes_rank = rank_of("notes.md").expect("notes.md present");
    assert!(
        auth_rank < notes_rank,
        "auth.log (2 terms) should outrank notes.md (1 term): {hits:?}"
    );
}
#[tokio::test]
async fn dci_retriever_scores_perfectly_on_exact_tokens() {
let mut q1 = GoldQuery {
query_id: "q-ip".to_string(),
query: "203.0.113.7".to_string(),
relevant_docs: Default::default(),
reference_answer: None,
};
q1.relevant_docs.insert("logs/auth.log".to_string(), 1);
q1.relevant_docs.insert("notes.md".to_string(), 1);
let mut q2 = GoldQuery {
query_id: "q-shadow".to_string(),
query: "shadow".to_string(),
relevant_docs: Default::default(),
reference_answer: None,
};
q2.relevant_docs.insert("logs/auth.log".to_string(), 1);
let qrels = Qrels {
queries: vec![q1, q2],
};
let dci = DciRetriever::new(fixtures());
let report = dci_tool::eval::evaluate(&dci, &qrels, &EvalConfig::default())
.await
.expect("evaluate");
let recall = report
.metrics
.iter()
.find(|m| m.metric.starts_with("recall"))
.expect("recall metric");
assert!(
recall.mean >= 0.99,
"DCI should fully recall exact tokens, got {}",
recall.mean
);
}
/// Every needle planted by the synthetic log generator must be recovered by
/// DCI, and the report must contain six metrics.
#[tokio::test]
async fn synthetic_needles_are_fully_recovered_by_dci() {
    let workdir = tempfile::tempdir().unwrap();
    let root = workdir.path();

    let qrels = synth::generate(
        root,
        &SyntheticLogConfig {
            files: 6,
            lines_per_file: 300,
            needles: 10,
            seed: 11,
        },
    )
    .expect("generate");

    let retriever = DciRetriever::new(CorpusRoot::new(root).expect("corpus"));
    let report = dci_tool::eval::evaluate(&retriever, &qrels, &EvalConfig::default())
        .await
        .expect("evaluate");

    let recall = report
        .metrics
        .iter()
        .find(|m| m.metric.starts_with("recall"))
        .expect("recall");
    assert_eq!(recall.mean, 1.0, "all needles should be recovered");
    assert_eq!(report.metrics.len(), 6, "all six IR metrics reported");
}
/// Runs DCI against the hashing-embedding vector baseline on one synthetic
/// corpus. It checks that both reports carry six metrics and that the
/// markdown summary contains the delta section. DCI recall must be at least
/// the baseline's.
#[tokio::test]
async fn head_to_head_dci_vs_vector_baseline_runs() {
    let workdir = tempfile::tempdir().unwrap();
    let config = SyntheticLogConfig {
        files: 4,
        lines_per_file: 120,
        needles: 8,
        seed: 5,
    };
    let qrels = synth::generate(workdir.path(), &config).expect("generate");

    let corpus = CorpusRoot::new(workdir.path()).expect("corpus");
    let dci = DciRetriever::new(corpus.clone());
    let baseline = VectorRetriever::build(&corpus, HashingEmbeddingModel)
        .await
        .expect("build vector baseline");

    let cmp = Comparison::run(&dci, &baseline, &qrels, &EvalConfig::default())
        .await
        .expect("comparison");
    assert_eq!(cmp.dci.metrics.len(), 6);
    assert_eq!(cmp.baseline.metrics.len(), 6);
    assert!(cmp.to_markdown().contains("Delta (DCI"));

    let dci_recall = cmp
        .dci
        .metrics
        .iter()
        .find(|m| m.metric.starts_with("recall"))
        .map(|m| m.mean)
        .unwrap();
    let base_recall = cmp
        .baseline
        .metrics
        .iter()
        .find(|m| m.metric.starts_with("recall"))
        .map(|m| m.mean)
        .unwrap();
    assert!(
        dci_recall >= base_recall,
        "DCI recall {dci_recall} should be >= vector baseline {base_recall}"
    );
}