use rig::vector_store::VectorStoreIndexDyn;
use tracing::instrument;
use crate::dataset::{Qrels, RetrievedSet};
use crate::error::{Error, Result};
use crate::report::MultiReport;
use crate::retrieval::RetrievalMetric;
use crate::retriever::{VectorStoreRetriever, retrieve_all, score_retriever};
/// Harness that evaluates retrieval quality against a single vector store.
///
/// Construct it with `RetrievalHarness::new` and adjust it through the
/// `with_*` builder methods before running.
pub struct RetrievalHarness<'s> {
    /// Number of results requested from the store for each query.
    k: usize,
    /// Borrowed vector store that every query is issued against.
    store: &'s dyn VectorStoreIndexDyn,
    /// Bootstrap settings stored as `(iterations, level, seed)`;
    /// `None` means no bootstrap is applied to the report.
    bootstrap: Option<(usize, f64, u64)>,
    /// Concurrency value forwarded to retrieval and scoring; the constructor
    /// and builder never let it drop below 1.
    concurrency: usize,
}
impl<'s> RetrievalHarness<'s> {
    /// Label given to the retriever wrapping the vector store. Shared by
    /// [`run`](Self::run) and [`retrieve_all`](Self::retrieve_all) so both
    /// paths describe the store the same way.
    const RETRIEVER_NAME: &'static str = "vector-store";

    /// Creates a harness that retrieves the top `k` results per query from `store`.
    ///
    /// Concurrency starts at 1 and bootstrap resampling is disabled. `k` is not
    /// validated here; [`run`](Self::run) and [`retrieve_all`](Self::retrieve_all)
    /// return an error when it is 0.
    pub fn new(store: &'s dyn VectorStoreIndexDyn, k: usize) -> Self {
        Self {
            store,
            k,
            concurrency: 1,
            bootstrap: None,
        }
    }

    /// Sets the concurrency forwarded to retrieval and scoring.
    ///
    /// Values below 1 are clamped to 1.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency.max(1);
        self
    }

    /// Enables bootstrap resampling of the report produced by [`run`](Self::run).
    ///
    /// `iterations`, `level` and `seed` are passed unchanged to the report's
    /// `with_bootstrap`. Passing `iterations == 0` disables bootstrap, including
    /// any settings from an earlier call.
    // NOTE(review): `level` is not validated here; confirm the expected scale
    // (e.g. a fraction such as 0.95) against `MultiReport::with_bootstrap`.
    #[must_use]
    pub fn with_bootstrap(mut self, iterations: usize, level: f64, seed: u64) -> Self {
        self.bootstrap = if iterations == 0 {
            None
        } else {
            Some((iterations, level, seed))
        };
        self
    }

    /// Returns the configured top-k.
    #[must_use]
    pub fn k(&self) -> usize {
        self.k
    }

    /// Scores retrieval from the vector store against `qrels` with every metric.
    ///
    /// If bootstrap settings were configured via
    /// [`with_bootstrap`](Self::with_bootstrap), they are applied to the
    /// resulting report.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when the top-k is 0, and propagates any error
    /// from scoring.
    #[instrument(skip_all, fields(evals.k = self.k, evals.queries = qrels.len(), evals.metrics = metrics.len()))]
    pub async fn run(
        &self,
        qrels: &Qrels,
        metrics: &[Box<dyn RetrievalMetric>],
    ) -> Result<MultiReport> {
        self.ensure_valid_k()?;
        let retriever = VectorStoreRetriever::new(self.store, Self::RETRIEVER_NAME);
        let report = score_retriever(&retriever, qrels, self.k, metrics, self.concurrency).await?;
        Ok(match self.bootstrap {
            Some((iterations, level, seed)) => report.with_bootstrap(iterations, level, seed),
            None => report,
        })
    }

    /// Runs every query in `qrels` against the vector store and returns the
    /// retrieved sets without scoring them.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when the top-k is 0, matching
    /// [`run`](Self::run), and propagates any error from retrieval.
    #[instrument(skip_all, fields(evals.k = self.k, evals.queries = qrels.len()))]
    pub async fn retrieve_all(&self, qrels: &Qrels) -> Result<Vec<RetrievedSet>> {
        // Same precondition as `run`; previously a zero top-k reached the
        // retriever here unchecked.
        self.ensure_valid_k()?;
        let retriever = VectorStoreRetriever::new(self.store, Self::RETRIEVER_NAME);
        retrieve_all(&retriever, qrels, self.k, self.concurrency).await
    }

    /// Shared precondition for all retrieval entry points: rejects a top-k of 0.
    fn ensure_valid_k(&self) -> Result<()> {
        if self.k == 0 {
            return Err(Error::Config("top-k must be > 0".into()));
        }
        Ok(())
    }
}