use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::join_all;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::retrievers::BaseRetriever;
/// Retriever that combines the ranked results of several retrievers into a
/// single ranking using weighted reciprocal rank fusion.
pub struct EnsembleRetriever {
    /// Retrievers whose results are fused.
    retrievers: Vec<Arc<dyn BaseRetriever>>,
    /// One weight per retriever, index-aligned with `retrievers`.
    weights: Vec<f32>,
    /// Maximum number of documents returned per query.
    k: usize,
    /// Smoothing constant handed to the rank-fusion step.
    rrf_k: usize,
}

impl EnsembleRetriever {
    /// Creates an ensemble over `retrievers`.
    ///
    /// Every retriever starts with the same weight `1 / n`; the result limit
    /// defaults to 4 and the fusion constant to 60.
    pub fn new(retrievers: Vec<Arc<dyn BaseRetriever>>) -> Self {
        let count = retrievers.len();
        // With no retrievers there is nothing to weight, and the guard also
        // keeps us from dividing by zero.
        let weights = match count {
            0 => Vec::new(),
            _ => vec![1.0 / count as f32; count],
        };
        Self {
            retrievers,
            weights,
            k: 4,
            rrf_k: 60,
        }
    }

    /// Replaces the per-retriever weights.
    ///
    /// # Panics
    ///
    /// Panics if `weights` does not contain exactly one entry per retriever.
    pub fn with_weights(self, weights: Vec<f32>) -> Self {
        assert_eq!(
            weights.len(),
            self.retrievers.len(),
            "weights length must match retrievers length"
        );
        Self { weights, ..self }
    }

    /// Sets the maximum number of documents returned per query.
    pub fn with_k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Sets the constant passed to the rank-fusion step as `rrf_k`.
    pub fn with_rrf_k(self, rrf_k: usize) -> Self {
        Self { rrf_k, ..self }
    }
}
#[async_trait]
impl BaseRetriever for EnsembleRetriever {
    /// Runs `query` against every underlying retriever and returns at most
    /// `k` documents from the fused ranking, best first.
    ///
    /// # Errors
    ///
    /// All retrievers are awaited to completion; if any of them failed, the
    /// error of the first failing retriever (in registration order) is
    /// returned.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        // Create every sub-query up front so `join_all` drives them together.
        let outcomes = join_all(
            self.retrievers
                .iter()
                .map(|retriever| retriever.get_relevant_documents(query)),
        )
        .await;

        // Collecting into a `Result` stops at the first `Err` in retriever order.
        let ranked_lists = outcomes
            .into_iter()
            .collect::<Result<Vec<Vec<Document>>>>()?;

        let fused = reciprocal_rank_fusion(&ranked_lists, &self.weights, self.rrf_k);
        let top_docs: Vec<Document> = fused
            .into_iter()
            .take(self.k)
            .map(|(doc, _score)| doc)
            .collect();
        Ok(top_docs)
    }
}
/// Fuses several ranked document lists into one ranking with weighted
/// reciprocal rank fusion (RRF).
///
/// A document at 1-based position `rank` in list `i` contributes
/// `weights[i] / (rrf_k + rank)` to its total score. Documents are identified
/// by their `page_content`, so the same text appearing in several lists (or
/// several times in one list) is scored once, with all contributions summed;
/// the first occurrence encountered is the `Document` that is returned.
///
/// # Arguments
///
/// * `results` - One ranked list per retriever, best match first.
/// * `weights` - Weight for each list, matched by index. A list with no
///   corresponding weight is given a weight of `1.0`.
/// * `rrf_k` - Smoothing constant added to every rank.
///
/// # Returns
///
/// `(document, score)` pairs sorted by descending score. Documents with equal
/// scores keep the order in which they were first encountered (lists in order,
/// then rank within a list), so the output is deterministic for a given input.
pub fn reciprocal_rank_fusion(
    results: &[Vec<Document>],
    weights: &[f32],
    rrf_k: usize,
) -> Vec<(Document, f32)> {
    // Maps each distinct page content to its slot in `scored`. Keeping the
    // accumulated entries in a Vec (rather than draining a HashMap) preserves
    // first-seen order, which is what makes tie-breaking reproducible.
    let mut slot_by_content: HashMap<&str, usize> = HashMap::new();
    let mut scored: Vec<(Document, f32)> = Vec::new();

    for (i, docs) in results.iter().enumerate() {
        let weight = weights.get(i).copied().unwrap_or(1.0);
        for (rank, doc) in docs.iter().enumerate() {
            let score = weight / (rrf_k as f32 + (rank + 1) as f32);

            // Existing slots are always < scored.len(), so getting `next_slot`
            // back means this content has not been seen before.
            let next_slot = scored.len();
            let slot = *slot_by_content
                .entry(doc.page_content.as_str())
                .or_insert(next_slot);
            if slot == next_slot {
                scored.push((doc.clone(), 0.0));
            }
            scored[slot].1 += score;
        }
    }

    // `sort_by` is stable, so equal scores stay in first-seen order.
    // `total_cmp` gives a genuine total order even if a NaN weight produced a
    // NaN score; a comparator that is not a total order may misorder elements
    // or make the sort panic.
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that returns a fixed, already-ranked list of documents for
    /// every query.
    struct MockRetriever {
        docs: Vec<Document>,
    }

    impl MockRetriever {
        /// Builds a mock yielding one document per entry of `contents`, in order.
        fn new(contents: &[&str]) -> Self {
            Self {
                docs: contents.iter().map(|c| Document::new(*c)).collect(),
            }
        }
    }

    #[async_trait]
    impl BaseRetriever for MockRetriever {
        async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
            Ok(self.docs.clone())
        }
    }

    /// Test double that never returns any documents.
    struct EmptyRetriever;

    #[async_trait]
    impl BaseRetriever for EmptyRetriever {
        async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
            Ok(vec![])
        }
    }

    /// Projects documents onto their text so rankings can be compared at once.
    fn page_contents(docs: &[Document]) -> Vec<&str> {
        docs.iter().map(|d| d.page_content.as_str()).collect()
    }

    #[tokio::test]
    async fn test_single_retriever_returns_its_results() {
        let r = Arc::new(MockRetriever::new(&["doc1", "doc2", "doc3"]));
        let ensemble = EnsembleRetriever::new(vec![r]).with_k(10);
        let docs = ensemble.get_relevant_documents("query").await.unwrap();
        // A single list has strictly decreasing scores, so its order is kept.
        assert_eq!(page_contents(&docs), vec!["doc1", "doc2", "doc3"]);
    }

    #[tokio::test]
    async fn test_two_retrievers_combine_via_rrf() {
        let r1 = Arc::new(MockRetriever::new(&["A", "B", "C"]));
        let r2 = Arc::new(MockRetriever::new(&["D", "A", "E"]));
        let ensemble = EnsembleRetriever::new(vec![r1, r2]).with_k(10);
        let docs = ensemble.get_relevant_documents("query").await.unwrap();
        assert_eq!(docs.len(), 5);
        // "A" is found by both retrievers; "D" (rank 1) outscores "B" (rank 2).
        // "C" and "E" tie at rank 3, so their relative order is not asserted.
        assert_eq!(docs[0].page_content, "A");
        assert_eq!(docs[1].page_content, "D");
        assert_eq!(docs[2].page_content, "B");
    }

    #[tokio::test]
    async fn test_weighted_retrievers_favor_higher_weight() {
        let r1 = Arc::new(MockRetriever::new(&["X"]));
        let r2 = Arc::new(MockRetriever::new(&["Y"]));
        let ensemble = EnsembleRetriever::new(vec![r1, r2])
            .with_weights(vec![0.1, 0.9])
            .with_k(10);
        let docs = ensemble.get_relevant_documents("query").await.unwrap();
        assert_eq!(page_contents(&docs), vec!["Y", "X"]);
    }

    #[tokio::test]
    async fn test_deduplication_across_retrievers() {
        let r1 = Arc::new(MockRetriever::new(&["shared", "only_r1"]));
        let r2 = Arc::new(MockRetriever::new(&["shared", "only_r2"]));
        let ensemble = EnsembleRetriever::new(vec![r1, r2]).with_k(10);
        let docs = ensemble.get_relevant_documents("query").await.unwrap();
        let contents = page_contents(&docs);
        assert_eq!(contents.iter().filter(|&&c| c == "shared").count(), 1);
        assert_eq!(docs.len(), 3);
    }

    #[tokio::test]
    async fn test_custom_k_limits_output() {
        let r = Arc::new(MockRetriever::new(&["a", "b", "c", "d", "e"]));
        let ensemble = EnsembleRetriever::new(vec![r]).with_k(2);
        let docs = ensemble.get_relevant_documents("query").await.unwrap();
        // Truncation keeps the two best-ranked documents.
        assert_eq!(page_contents(&docs), vec!["a", "b"]);
    }

    #[tokio::test]
    async fn test_empty_results_from_one_retriever() {
        let r1 = Arc::new(MockRetriever::new(&["real_doc"]));
        let r2: Arc<dyn BaseRetriever> = Arc::new(EmptyRetriever);
        let ensemble = EnsembleRetriever::new(vec![r1, r2]).with_k(10);
        let docs = ensemble.get_relevant_documents("query").await.unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].page_content, "real_doc");
    }

    // Purely synchronous: no async runtime is needed to exercise the scoring.
    #[test]
    fn test_rrf_scoring_standalone() {
        let docs_a = vec![Document::new("X"), Document::new("Y")];
        let docs_b = vec![Document::new("Y"), Document::new("Z")];
        let results = vec![docs_a, docs_b];
        let weights = vec![0.5, 0.5];
        let scored = reciprocal_rank_fusion(&results, &weights, 60);
        assert_eq!(scored.len(), 3);
        // "Y" appears in both lists; "X" (rank 1) outscores "Z" (rank 2).
        assert_eq!(scored[0].0.page_content, "Y");
        assert_eq!(scored[1].0.page_content, "X");
        assert_eq!(scored[2].0.page_content, "Z");
        assert!(scored[0].1 > scored[1].1 && scored[1].1 > scored[2].1);
    }
}