Skip to main content

brainwires_rag/rag/client/
ensemble.rs

//! Multi-strategy ensemble query (Reciprocal Rank Fusion) for [`RagClient`].
2
3use super::RagClient;
4use crate::rag::types::*;
5use anyhow::{Context, Result};
6use std::collections::HashMap;
7use std::time::Instant;
8
9impl RagClient {
10    /// Multi-strategy ensemble query: fan out across all requested strategies
11    /// concurrently, fuse results via Reciprocal Rank Fusion (RRF), and
12    /// optionally apply spectral diversity reranking as a final pass.
13    ///
14    /// ## Strategies
15    ///
16    /// - `Semantic` — vector similarity search
17    /// - `Keyword` — BM25 keyword / hybrid search
18    /// - `GitHistory` — semantic search over commit history
19    /// - `CodeNavigation` — AST-based relations search (requires `code-analysis`)
20    ///
21    /// ## Fusion
22    ///
23    /// Results from each strategy are deduplicated by `file_path:start_line` and
24    /// fused using RRF so that items appearing near the top of multiple strategy
25    /// lists rank highest overall.
26    pub async fn query_ensemble(&self, request: EnsembleRequest) -> Result<EnsembleResponse> {
27        use brainwires_storage::bm25_search::reciprocal_rank_fusion_generic;
28
29        let start = Instant::now();
30
31        // Determine active strategies.
32        let active: Vec<SearchStrategy> = if request.strategies.is_empty() {
33            #[allow(unused_mut)]
34            let mut s = vec![
35                SearchStrategy::Semantic,
36                SearchStrategy::Keyword,
37                SearchStrategy::GitHistory,
38            ];
39            s.push(SearchStrategy::CodeNavigation);
40            s
41        } else {
42            request.strategies.clone()
43        };
44
45        // Embed the query once.
46        let query_embedding = self
47            .embedding_provider
48            .embed(&request.query)
49            .context("Failed to generate query embedding for ensemble")?;
50
51        // Fan out across strategies concurrently.
52        // Each strategy returns (strategy_name, Vec<SearchResult>).
53        let path = request.path.clone();
54        let project = request.project.clone();
55        let query = request.query.clone();
56        let limit = request.limit;
57        let min_score = request.min_score;
58        let file_extensions = request.file_extensions.clone();
59        let languages = request.languages.clone();
60
61        // Build strategy futures as boxed async closures resolved concurrently.
62        let mut strategy_futures = Vec::new();
63
64        for strategy in &active {
65            match strategy {
66                SearchStrategy::Semantic => {
67                    let qe = query_embedding.clone();
68                    let q = query.clone();
69                    let pa = path.clone();
70                    let pr = project.clone();
71                    let db = self.vector_db.clone();
72                    strategy_futures.push(tokio::spawn(async move {
73                        let results = db
74                            .search(qe, &q, limit * 2, min_score, pr, pa, false)
75                            .await
76                            .unwrap_or_default();
77                        ("semantic".to_string(), results)
78                    }));
79                }
80                SearchStrategy::Keyword => {
81                    let qe = query_embedding.clone();
82                    let q = query.clone();
83                    let pa = path.clone();
84                    let pr = project.clone();
85                    let db = self.vector_db.clone();
86                    let exts = file_extensions.clone();
87                    let langs = languages.clone();
88                    strategy_futures.push(tokio::spawn(async move {
89                        let results = if exts.is_empty() && langs.is_empty() {
90                            db.search(qe, &q, limit * 2, min_score, pr, pa, true)
91                                .await
92                                .unwrap_or_default()
93                        } else {
94                            db.search_filtered(
95                                qe,
96                                &q,
97                                limit * 2,
98                                min_score,
99                                pr,
100                                pa,
101                                true,
102                                exts,
103                                langs,
104                                Vec::new(),
105                            )
106                            .await
107                            .unwrap_or_default()
108                        };
109                        ("keyword".to_string(), results)
110                    }));
111                }
112                SearchStrategy::GitHistory => {
113                    let ep = self.embedding_provider.clone();
114                    let db = self.vector_db.clone();
115                    let gc = self.git_cache.clone();
116                    let gp = self.git_cache_path.clone();
117                    let q = query.clone();
118                    let pa = path.clone().unwrap_or_else(|| ".".to_string());
119                    let pr = project.clone();
120                    strategy_futures.push(tokio::spawn(async move {
121                        use crate::rag::client::git_indexing;
122                        use brainwires_core::SearchResult;
123                        let git_req = SearchGitHistoryRequest {
124                            query: q,
125                            path: pa,
126                            project: pr,
127                            branch: None,
128                            max_commits: 200,
129                            limit: limit * 2,
130                            min_score,
131                            author: None,
132                            since: None,
133                            until: None,
134                            file_pattern: None,
135                        };
136                        let resp: SearchGitHistoryResponse =
137                            git_indexing::do_search_git_history(ep, db, gc, &gp, git_req)
138                                .await
139                                .unwrap_or(SearchGitHistoryResponse {
140                                    results: Vec::new(),
141                                    commits_indexed: 0,
142                                    total_cached_commits: 0,
143                                    duration_ms: 0,
144                                });
145                        let results: Vec<SearchResult> = resp
146                            .results
147                            .into_iter()
148                            .map(|g| SearchResult {
149                                file_path: g.commit_hash.clone(),
150                                root_path: None,
151                                content: format!("{}\n{}", g.commit_message, g.diff_snippet),
152                                score: g.score,
153                                vector_score: g.vector_score,
154                                keyword_score: g.keyword_score,
155                                start_line: 0,
156                                end_line: 0,
157                                language: "git".to_string(),
158                                project: None,
159                                indexed_at: g.commit_date,
160                            })
161                            .collect();
162                        ("git_history".to_string(), results)
163                    }));
164                }
165                SearchStrategy::CodeNavigation => {
166                    let qe = query_embedding.clone();
167                    let db = self.vector_db.clone();
168                    let q = query.clone();
169                    let pa = path.clone();
170                    let pr = project.clone();
171                    strategy_futures.push(tokio::spawn(async move {
172                        let results = db
173                            .search(qe, &q, limit * 2, min_score, pr, pa, false)
174                            .await
175                            .unwrap_or_default();
176                        ("code_navigation".to_string(), results)
177                    }));
178                }
179            }
180        }
181
182        // Collect strategy results.
183        let mut all_results: HashMap<String, SearchResult> = HashMap::new();
184        let mut strategy_lists: Vec<Vec<(String, f32)>> = Vec::new();
185        let mut strategies_used: Vec<String> = Vec::new();
186        let mut per_strategy_counts: HashMap<String, usize> = HashMap::new();
187
188        for handle in strategy_futures {
189            match handle.await {
190                Ok((name, results)) => {
191                    per_strategy_counts.insert(name.clone(), results.len());
192                    let ranked: Vec<(String, f32)> = results
193                        .iter()
194                        .map(|r| {
195                            let key = format!("{}:{}", r.file_path, r.start_line);
196                            all_results.entry(key.clone()).or_insert_with(|| r.clone());
197                            (key, r.score)
198                        })
199                        .collect();
200                    if !ranked.is_empty() {
201                        strategies_used.push(name);
202                        strategy_lists.push(ranked);
203                    }
204                }
205                Err(e) => {
206                    tracing::warn!("Ensemble strategy task failed: {e}");
207                }
208            }
209        }
210
211        // RRF fusion across all strategy ranked lists.
212        let fused: Vec<(String, f32)> = reciprocal_rank_fusion_generic(strategy_lists, limit);
213
214        // Resolve fused keys back to SearchResult, overriding score with RRF score.
215        let mut results: Vec<SearchResult> = fused
216            .into_iter()
217            .filter_map(|(key, rrf_score)| {
218                all_results.get(&key).map(|r| {
219                    let mut result = r.clone();
220                    result.score = rrf_score;
221                    result
222                })
223            })
224            .collect();
225
226        // Optional spectral reranking as a final diversity pass.
227        if request.spectral_rerank && results.len() > limit {
228            use crate::spectral::{DiversityReranker, SpectralReranker, SpectralSelectConfig};
229            let keys: Vec<String> = results
230                .iter()
231                .map(|r| format!("{}:{}", r.file_path, r.start_line))
232                .collect();
233            // Re-fetch embeddings for the fused candidates.
234            if let Ok((_, embeddings)) = self
235                .vector_db
236                .search_with_embeddings(
237                    query_embedding.clone(),
238                    &request.query,
239                    results.len(),
240                    0.0,
241                    request.project.clone(),
242                    request.path.clone(),
243                    false,
244                )
245                .await
246            {
247                // Build a key→embedding map from the re-fetched results.
248                let _ = keys; // suppress unused warning
249                if embeddings.len() == results.len() {
250                    let reranker = SpectralReranker::new(SpectralSelectConfig::default());
251                    let indices = reranker.rerank(&results, &embeddings, limit);
252                    results = indices.into_iter().map(|i| results[i].clone()).collect();
253                } else {
254                    results.truncate(limit);
255                }
256            } else {
257                results.truncate(limit);
258            }
259        }
260
261        results.truncate(limit);
262
263        Ok(EnsembleResponse {
264            results,
265            duration_ms: start.elapsed().as_millis() as u64,
266            strategies_used,
267            per_strategy_counts,
268        })
269    }
270}