Skip to main content

cognis_rag/retrievers/
ensemble.rs

1//! Ensemble retriever — combines multiple retrievers via Reciprocal Rank
2//! Fusion.
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use futures::future::join_all;
9
10use cognis_core::{Result, Runnable, RunnableConfig};
11
12use crate::document::Document;
13
14/// One retriever inside an [`EnsembleRetriever`] paired with its weight.
15type WeightedRetriever = (Arc<dyn Runnable<String, Vec<Document>>>, f32);
16
17/// Combines results from N retrievers using Reciprocal Rank Fusion (RRF).
18///
19/// For each retriever's ranked list, each document earns a score of
20/// `weight / (rrf_k + rank)`. Scores across retrievers sum; final results
21/// are sorted descending and truncated to `top_k`.
22///
23/// Common pattern: combine a [`super::VectorRetriever`] (semantic) with a
24/// [`super::BM25Retriever`] (lexical) to capture both meaning and exact-term
25/// matches.
26pub struct EnsembleRetriever {
27    retrievers: Vec<WeightedRetriever>,
28    top_k: usize,
29    rrf_k: f32,
30}
31
32impl EnsembleRetriever {
33    /// Build an empty ensemble.
34    pub fn new() -> Self {
35        Self {
36            retrievers: Vec::new(),
37            top_k: 4,
38            rrf_k: 60.0,
39        }
40    }
41
42    /// Add a retriever with a contribution weight (default 1.0).
43    pub fn with_retriever(
44        mut self,
45        retriever: Arc<dyn Runnable<String, Vec<Document>>>,
46        weight: f32,
47    ) -> Self {
48        self.retrievers.push((retriever, weight));
49        self
50    }
51
52    /// Override the final top-k.
53    pub fn with_top_k(mut self, k: usize) -> Self {
54        self.top_k = k;
55        self
56    }
57
58    /// Override the RRF `k` constant (default 60.0 — RRF paper).
59    pub fn with_rrf_k(mut self, k: f32) -> Self {
60        self.rrf_k = k;
61        self
62    }
63}
64
65impl Default for EnsembleRetriever {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71#[async_trait]
72impl Runnable<String, Vec<Document>> for EnsembleRetriever {
73    async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
74        let calls = self.retrievers.iter().map(|(r, w)| {
75            let r = r.clone();
76            let q = query.clone();
77            let cfg = config.clone();
78            let weight = *w;
79            async move {
80                let docs = r.invoke(q, cfg).await?;
81                Ok::<(Vec<Document>, f32), cognis_core::CognisError>((docs, weight))
82            }
83        });
84        let lists = join_all(calls)
85            .await
86            .into_iter()
87            .collect::<Result<Vec<_>>>()?;
88
89        // Fuse via RRF.
90        let mut scores: HashMap<String, (f32, Document)> = HashMap::new();
91        for (docs, weight) in lists {
92            for (rank, doc) in docs.into_iter().enumerate() {
93                let key = doc.id.clone().unwrap_or_else(|| doc.content.clone());
94                let contribution = weight / (self.rrf_k + rank as f32 + 1.0);
95                scores
96                    .entry(key)
97                    .and_modify(|(s, _)| *s += contribution)
98                    .or_insert((contribution, doc));
99            }
100        }
101
102        let mut all: Vec<(String, f32, Document)> =
103            scores.into_iter().map(|(k, (s, d))| (k, s, d)).collect();
104        all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
105        Ok(all
106            .into_iter()
107            .take(self.top_k)
108            .map(|(_, _, d)| d)
109            .collect())
110    }
111
112    fn name(&self) -> &str {
113        "EnsembleRetriever"
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Retriever stub that ignores the query and always yields the same list.
    struct StaticRetriever(Vec<Document>);

    #[async_trait]
    impl Runnable<String, Vec<Document>> for StaticRetriever {
        async fn invoke(&self, _query: String, _config: RunnableConfig) -> Result<Vec<Document>> {
            Ok(self.0.clone())
        }
    }

    /// Build a stub retriever whose documents use each label as both content
    /// and id, in the given rank order.
    fn fixed(labels: &[&'static str]) -> Arc<dyn Runnable<String, Vec<Document>>> {
        let docs = labels
            .iter()
            .map(|label| Document::new(*label).with_id(*label))
            .collect();
        Arc::new(StaticRetriever(docs))
    }

    #[tokio::test]
    async fn fuses_two_retrievers_by_rrf() {
        let first = fixed(&["a", "b", "c"]);
        let second = fixed(&["c", "a", "d"]);

        let ensemble = EnsembleRetriever::new()
            .with_retriever(first, 1.0)
            .with_retriever(second, 1.0)
            .with_top_k(3);
        let fused = ensemble
            .invoke("query".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<String> = fused.iter().filter_map(|doc| doc.id.clone()).collect();

        // `a` is ranked 1st and 2nd across the two lists, so it must win.
        assert_eq!(ids[0], "a");
        // `c` is also present in both lists and must survive the top-3 cut.
        assert!(ids.iter().any(|id| id == "c"));
    }

    #[tokio::test]
    async fn empty_ensemble_returns_empty() {
        let ensemble = EnsembleRetriever::new();
        let fused = ensemble
            .invoke("q".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(fused.is_empty());
    }
}