brainwires_rag/rag/client/
ensemble.rs1use super::RagClient;
4use crate::rag::types::*;
5use anyhow::{Context, Result};
6use std::collections::HashMap;
7use std::time::Instant;
8
9impl RagClient {
10 pub async fn query_ensemble(&self, request: EnsembleRequest) -> Result<EnsembleResponse> {
27 use brainwires_storage::bm25_search::reciprocal_rank_fusion_generic;
28
29 let start = Instant::now();
30
31 let active: Vec<SearchStrategy> = if request.strategies.is_empty() {
33 #[allow(unused_mut)]
34 let mut s = vec![
35 SearchStrategy::Semantic,
36 SearchStrategy::Keyword,
37 SearchStrategy::GitHistory,
38 ];
39 s.push(SearchStrategy::CodeNavigation);
40 s
41 } else {
42 request.strategies.clone()
43 };
44
45 let query_embedding = self
47 .embedding_provider
48 .embed(&request.query)
49 .context("Failed to generate query embedding for ensemble")?;
50
51 let path = request.path.clone();
54 let project = request.project.clone();
55 let query = request.query.clone();
56 let limit = request.limit;
57 let min_score = request.min_score;
58 let file_extensions = request.file_extensions.clone();
59 let languages = request.languages.clone();
60
61 let mut strategy_futures = Vec::new();
63
64 for strategy in &active {
65 match strategy {
66 SearchStrategy::Semantic => {
67 let qe = query_embedding.clone();
68 let q = query.clone();
69 let pa = path.clone();
70 let pr = project.clone();
71 let db = self.vector_db.clone();
72 strategy_futures.push(tokio::spawn(async move {
73 let results = db
74 .search(qe, &q, limit * 2, min_score, pr, pa, false)
75 .await
76 .unwrap_or_default();
77 ("semantic".to_string(), results)
78 }));
79 }
80 SearchStrategy::Keyword => {
81 let qe = query_embedding.clone();
82 let q = query.clone();
83 let pa = path.clone();
84 let pr = project.clone();
85 let db = self.vector_db.clone();
86 let exts = file_extensions.clone();
87 let langs = languages.clone();
88 strategy_futures.push(tokio::spawn(async move {
89 let results = if exts.is_empty() && langs.is_empty() {
90 db.search(qe, &q, limit * 2, min_score, pr, pa, true)
91 .await
92 .unwrap_or_default()
93 } else {
94 db.search_filtered(
95 qe,
96 &q,
97 limit * 2,
98 min_score,
99 pr,
100 pa,
101 true,
102 exts,
103 langs,
104 Vec::new(),
105 )
106 .await
107 .unwrap_or_default()
108 };
109 ("keyword".to_string(), results)
110 }));
111 }
112 SearchStrategy::GitHistory => {
113 let ep = self.embedding_provider.clone();
114 let db = self.vector_db.clone();
115 let gc = self.git_cache.clone();
116 let gp = self.git_cache_path.clone();
117 let q = query.clone();
118 let pa = path.clone().unwrap_or_else(|| ".".to_string());
119 let pr = project.clone();
120 strategy_futures.push(tokio::spawn(async move {
121 use crate::rag::client::git_indexing;
122 use brainwires_core::SearchResult;
123 let git_req = SearchGitHistoryRequest {
124 query: q,
125 path: pa,
126 project: pr,
127 branch: None,
128 max_commits: 200,
129 limit: limit * 2,
130 min_score,
131 author: None,
132 since: None,
133 until: None,
134 file_pattern: None,
135 };
136 let resp: SearchGitHistoryResponse =
137 git_indexing::do_search_git_history(ep, db, gc, &gp, git_req)
138 .await
139 .unwrap_or(SearchGitHistoryResponse {
140 results: Vec::new(),
141 commits_indexed: 0,
142 total_cached_commits: 0,
143 duration_ms: 0,
144 });
145 let results: Vec<SearchResult> = resp
146 .results
147 .into_iter()
148 .map(|g| SearchResult {
149 file_path: g.commit_hash.clone(),
150 root_path: None,
151 content: format!("{}\n{}", g.commit_message, g.diff_snippet),
152 score: g.score,
153 vector_score: g.vector_score,
154 keyword_score: g.keyword_score,
155 start_line: 0,
156 end_line: 0,
157 language: "git".to_string(),
158 project: None,
159 indexed_at: g.commit_date,
160 })
161 .collect();
162 ("git_history".to_string(), results)
163 }));
164 }
165 SearchStrategy::CodeNavigation => {
166 let qe = query_embedding.clone();
167 let db = self.vector_db.clone();
168 let q = query.clone();
169 let pa = path.clone();
170 let pr = project.clone();
171 strategy_futures.push(tokio::spawn(async move {
172 let results = db
173 .search(qe, &q, limit * 2, min_score, pr, pa, false)
174 .await
175 .unwrap_or_default();
176 ("code_navigation".to_string(), results)
177 }));
178 }
179 }
180 }
181
182 let mut all_results: HashMap<String, SearchResult> = HashMap::new();
184 let mut strategy_lists: Vec<Vec<(String, f32)>> = Vec::new();
185 let mut strategies_used: Vec<String> = Vec::new();
186 let mut per_strategy_counts: HashMap<String, usize> = HashMap::new();
187
188 for handle in strategy_futures {
189 match handle.await {
190 Ok((name, results)) => {
191 per_strategy_counts.insert(name.clone(), results.len());
192 let ranked: Vec<(String, f32)> = results
193 .iter()
194 .map(|r| {
195 let key = format!("{}:{}", r.file_path, r.start_line);
196 all_results.entry(key.clone()).or_insert_with(|| r.clone());
197 (key, r.score)
198 })
199 .collect();
200 if !ranked.is_empty() {
201 strategies_used.push(name);
202 strategy_lists.push(ranked);
203 }
204 }
205 Err(e) => {
206 tracing::warn!("Ensemble strategy task failed: {e}");
207 }
208 }
209 }
210
211 let fused: Vec<(String, f32)> = reciprocal_rank_fusion_generic(strategy_lists, limit);
213
214 let mut results: Vec<SearchResult> = fused
216 .into_iter()
217 .filter_map(|(key, rrf_score)| {
218 all_results.get(&key).map(|r| {
219 let mut result = r.clone();
220 result.score = rrf_score;
221 result
222 })
223 })
224 .collect();
225
226 if request.spectral_rerank && results.len() > limit {
228 use crate::spectral::{DiversityReranker, SpectralReranker, SpectralSelectConfig};
229 let keys: Vec<String> = results
230 .iter()
231 .map(|r| format!("{}:{}", r.file_path, r.start_line))
232 .collect();
233 if let Ok((_, embeddings)) = self
235 .vector_db
236 .search_with_embeddings(
237 query_embedding.clone(),
238 &request.query,
239 results.len(),
240 0.0,
241 request.project.clone(),
242 request.path.clone(),
243 false,
244 )
245 .await
246 {
247 let _ = keys; if embeddings.len() == results.len() {
250 let reranker = SpectralReranker::new(SpectralSelectConfig::default());
251 let indices = reranker.rerank(&results, &embeddings, limit);
252 results = indices.into_iter().map(|i| results[i].clone()).collect();
253 } else {
254 results.truncate(limit);
255 }
256 } else {
257 results.truncate(limit);
258 }
259 }
260
261 results.truncate(limit);
262
263 Ok(EnsembleResponse {
264 results,
265 duration_ms: start.elapsed().as_millis() as u64,
266 strategies_used,
267 per_strategy_counts,
268 })
269 }
270}