use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use futures::stream::{self, StreamExt};
use rig::vector_store::{VectorSearchRequest, VectorStoreIndexDyn, request::Filter};
use crate::dataset::{Qrels, RetrievedDoc, RetrievedSet};
use crate::error::{Error, Result};
use crate::report::{MetricReport, MultiReport};
use crate::retrieval::RetrievalMetric;
/// Boxed, pinned, `Send` future that resolves to `Result<Vec<RetrievedDoc>>`.
///
/// The `'a` bound lets the future borrow data that lives at least as long as `'a`
/// instead of requiring everything it captures to be `'static`.
pub type RetrieveFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<RetrievedDoc>>> + Send + 'a>>;
/// A named retrieval backend that can be shared across threads (`Send + Sync`).
///
/// `retrieve` returns a boxed future rather than being an `async fn`, so the trait
/// can be used through `&dyn Retriever`, as the functions in this file do.
pub trait Retriever: Send + Sync {
/// Returns the retriever's name.
fn name(&self) -> &str;
/// Starts a retrieval for `query`, with `k` as the requested top-k size.
///
/// The returned future borrows both `self` and `query` for `'a` and resolves to the
/// retrieved documents or an error.
fn retrieve<'a>(&'a self, query: &'a str, k: usize) -> RetrieveFuture<'a>;
}
/// [`Retriever`] implementation backed by a borrowed `VectorStoreIndexDyn`.
///
/// The index is held by reference, so the retriever cannot outlive the store (`'s`).
pub struct VectorStoreRetriever<'s> {
/// Borrowed vector index that serves the searches.
store: &'s dyn VectorStoreIndexDyn,
/// Label returned by `Retriever::name`.
name: String,
}
impl<'s> VectorStoreRetriever<'s> {
pub fn new(store: &'s dyn VectorStoreIndexDyn, name: impl Into<String>) -> Self {
Self {
store,
name: name.into(),
}
}
}
impl Retriever for VectorStoreRetriever<'_> {
    /// Returns the label supplied at construction time.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Searches the wrapped store for `query`, requesting `k` samples, and converts
    /// each `(score, doc_id)` hit into a `RetrievedDoc` in the order the store
    /// returned them.
    ///
    /// Errors from the store's `top_n_ids` call are propagated through `?`.
    fn retrieve<'a>(&'a self, query: &'a str, k: usize) -> RetrieveFuture<'a> {
        // Keep the async block as the direct argument of `Box::pin` so the expected
        // return type drives inference of the block's `Result` type.
        Box::pin(async move {
            let request: VectorSearchRequest<Filter<serde_json::Value>> =
                VectorSearchRequest::builder()
                    .query(query.to_string())
                    .samples(k as u64)
                    .build();

            let scored_ids = self.store.top_n_ids(request).await?;

            let mut docs = Vec::with_capacity(scored_ids.len());
            for (score, doc_id) in scored_ids {
                docs.push(RetrievedDoc { doc_id, score });
            }
            Ok(docs)
        })
    }
}
pub async fn retrieve_all(
retriever: &dyn Retriever,
qrels: &Qrels,
k: usize,
concurrency: usize,
) -> Result<Vec<RetrievedSet>> {
let results: Vec<Result<RetrievedSet>> =
stream::iter(qrels.queries.iter().map(|q| async move {
let ranked = retriever.retrieve(&q.query, k).await?;
Ok(RetrievedSet {
query_id: q.query_id.clone(),
ranked,
})
}))
.buffered(concurrency.max(1))
.collect()
.await;
results.into_iter().collect()
}
/// Retrieves the top-`k` documents for every query in `qrels` and scores the
/// rankings with each metric in `metrics`.
///
/// One [`MetricReport`] is produced per metric, in the order of `metrics`; each
/// report receives a `(query_id, score)` entry for every query, in the order of
/// `qrels.queries`. `concurrency` is passed through to [`retrieve_all`].
///
/// # Errors
///
/// * `Error::Config` when `k` is zero.
/// * Any error returned by [`retrieve_all`].
pub async fn score_retriever(
    retriever: &dyn Retriever,
    qrels: &Qrels,
    k: usize,
    metrics: &[Box<dyn RetrievalMetric>],
    concurrency: usize,
) -> Result<MultiReport> {
    if k == 0 {
        return Err(Error::Config("top-k must be > 0".into()));
    }

    let retrievals = retrieve_all(retriever, qrels, k, concurrency).await?;

    // `retrieve_all` returns exactly one set per query, in query order, so queries
    // and retrievals are paired positionally. A lookup keyed by `query_id` would
    // collapse duplicate ids and score them against the wrong retrieval.
    debug_assert_eq!(
        retrievals.len(),
        qrels.queries.len(),
        "retrieve_all must return one set per query"
    );

    let mut reports: Vec<MetricReport> = Vec::with_capacity(metrics.len());
    for metric in metrics {
        let name = metric.name();
        let per_query: Vec<_> = qrels
            .queries
            .iter()
            .zip(&retrievals)
            .map(|(q, retrieved)| (q.query_id.clone(), metric.score(q, retrieved)))
            .collect();
        reports.push(MetricReport::from_per_query(name, per_query));
    }
    Ok(MultiReport::new(reports))
}