cognis_rag/retrievers/
ensemble.rs1use std::collections::HashMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use futures::future::join_all;
9
10use cognis_core::{Result, Runnable, RunnableConfig};
11
12use crate::document::Document;
13
14type WeightedRetriever = (Arc<dyn Runnable<String, Vec<Document>>>, f32);
16
17pub struct EnsembleRetriever {
27 retrievers: Vec<WeightedRetriever>,
28 top_k: usize,
29 rrf_k: f32,
30}
31
32impl EnsembleRetriever {
33 pub fn new() -> Self {
35 Self {
36 retrievers: Vec::new(),
37 top_k: 4,
38 rrf_k: 60.0,
39 }
40 }
41
42 pub fn with_retriever(
44 mut self,
45 retriever: Arc<dyn Runnable<String, Vec<Document>>>,
46 weight: f32,
47 ) -> Self {
48 self.retrievers.push((retriever, weight));
49 self
50 }
51
52 pub fn with_top_k(mut self, k: usize) -> Self {
54 self.top_k = k;
55 self
56 }
57
58 pub fn with_rrf_k(mut self, k: f32) -> Self {
60 self.rrf_k = k;
61 self
62 }
63}
64
65impl Default for EnsembleRetriever {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71#[async_trait]
72impl Runnable<String, Vec<Document>> for EnsembleRetriever {
73 async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
74 let calls = self.retrievers.iter().map(|(r, w)| {
75 let r = r.clone();
76 let q = query.clone();
77 let cfg = config.clone();
78 let weight = *w;
79 async move {
80 let docs = r.invoke(q, cfg).await?;
81 Ok::<(Vec<Document>, f32), cognis_core::CognisError>((docs, weight))
82 }
83 });
84 let lists = join_all(calls)
85 .await
86 .into_iter()
87 .collect::<Result<Vec<_>>>()?;
88
89 let mut scores: HashMap<String, (f32, Document)> = HashMap::new();
91 for (docs, weight) in lists {
92 for (rank, doc) in docs.into_iter().enumerate() {
93 let key = doc.id.clone().unwrap_or_else(|| doc.content.clone());
94 let contribution = weight / (self.rrf_k + rank as f32 + 1.0);
95 scores
96 .entry(key)
97 .and_modify(|(s, _)| *s += contribution)
98 .or_insert((contribution, doc));
99 }
100 }
101
102 let mut all: Vec<(String, f32, Document)> =
103 scores.into_iter().map(|(k, (s, d))| (k, s, d)).collect();
104 all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
105 Ok(all
106 .into_iter()
107 .take(self.top_k)
108 .map(|(_, _, d)| d)
109 .collect())
110 }
111
112 fn name(&self) -> &str {
113 "EnsembleRetriever"
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    // Test double: returns the same fixed documents for every query.
    struct StaticRetriever(Vec<Document>);

    #[async_trait]
    impl Runnable<String, Vec<Document>> for StaticRetriever {
        async fn invoke(&self, _q: String, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(self.0.clone())
        }
    }

    /// Shared `StaticRetriever` whose documents use each entry of `ids` as
    /// both id and content, in the given order.
    fn static_retriever(ids: &[&'static str]) -> Arc<dyn Runnable<String, Vec<Document>>> {
        Arc::new(StaticRetriever(
            ids.iter().map(|&id| Document::new(id).with_id(id)).collect(),
        ))
    }

    /// Runs `ens` on a dummy query and returns the fused documents' ids in order.
    async fn fused_ids(ens: &EnsembleRetriever) -> Vec<String> {
        let out = ens
            .invoke("query".into(), RunnableConfig::default())
            .await
            .unwrap();
        out.iter().filter_map(|d| d.id.clone()).collect()
    }

    #[tokio::test]
    async fn fuses_two_retrievers_by_rrf() {
        // With rrf_k = 60: a = 1/61 + 1/62, c = 1/63 + 1/61, b = 1/62, d = 1/63,
        // so the fused order is a > c > b > d and top_k = 3 drops d.
        let ens = EnsembleRetriever::new()
            .with_retriever(static_retriever(&["a", "b", "c"]), 1.0)
            .with_retriever(static_retriever(&["c", "a", "d"]), 1.0)
            .with_top_k(3);
        assert_eq!(fused_ids(&ens).await, ["a", "c", "b"]);
    }

    #[tokio::test]
    async fn weights_scale_each_retrievers_contribution() {
        // The first retriever has weight 0, so only the second list counts:
        // c = 1/61 > a = 1/62 > d = 1/63 > b = 0.
        let ens = EnsembleRetriever::new()
            .with_retriever(static_retriever(&["a", "b", "c"]), 0.0)
            .with_retriever(static_retriever(&["c", "a", "d"]), 1.0)
            .with_top_k(3);
        assert_eq!(fused_ids(&ens).await, ["c", "a", "d"]);
    }

    #[tokio::test]
    async fn top_k_limits_result_count() {
        let ens = EnsembleRetriever::new()
            .with_retriever(static_retriever(&["a", "b", "c"]), 1.0)
            .with_top_k(2);
        assert_eq!(fused_ids(&ens).await, ["a", "b"]);
    }

    #[tokio::test]
    async fn empty_ensemble_returns_empty() {
        let ens = EnsembleRetriever::new();
        let out = ens
            .invoke("q".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(out.is_empty());
    }
}