Skip to main content

brainwires_rag/rag/client/
reranking.rs

//! Diversity / relevance reranking via pluggable reranker strategies for [`RagClient`].
//!
//! Requires the `spectral` feature.
4
5use super::RagClient;
6use crate::rag::types::*;
7use anyhow::{Context, Result};
8use std::time::Instant;
9
10impl RagClient {
11    /// Query the indexed codebase with pluggable diversity/relevance reranking.
12    ///
13    /// This oversamples candidates (3× the limit), then applies the chosen
14    /// reranker to select the final result set.  Pass `None` to use the default
15    /// spectral reranker with its default configuration.
16    ///
17    /// ## Reranker options
18    ///
19    /// - [`RerankerKind::Spectral`](crate::spectral::RerankerKind::Spectral) — greedy log-det maximization (diversity-focused)
20    /// - [`RerankerKind::CrossEncoder`](crate::spectral::RerankerKind::CrossEncoder) — query-document cosine blend (relevance-focused)
21    /// - [`RerankerKind::Both`](crate::spectral::RerankerKind::Both) — spectral first, then cross-encoder on the selected subset
22    ///
23    /// Requires the `spectral` feature.
24    pub async fn query_diverse(
25        &self,
26        request: QueryRequest,
27        reranker: Option<crate::spectral::RerankerKind>,
28    ) -> Result<QueryResponse> {
29        use crate::spectral::{
30            CrossEncoderReranker, DiversityReranker, RerankerKind, SpectralReranker,
31        };
32
33        request.validate().map_err(|e| anyhow::anyhow!(e))?;
34        self.check_path_not_dirty(request.path.as_deref()).await?;
35
36        let start = Instant::now();
37
38        // Determine final_k from the reranker config or the request limit.
39        let final_k = match &reranker {
40            Some(RerankerKind::Spectral(cfg)) => cfg.k.unwrap_or(request.limit),
41            Some(RerankerKind::Both { spectral, .. }) => spectral.k.unwrap_or(request.limit),
42            _ => request.limit,
43        };
44
45        // Oversample: retrieve 3× candidates for the reranker to select from.
46        let oversample_limit = final_k * 3;
47
48        let query_embedding = self
49            .embedding_provider
50            .embed(&request.query)
51            .context("Failed to generate query embedding")?;
52
53        let original_threshold = request.min_score;
54        let mut threshold_used = original_threshold;
55        let mut threshold_lowered = false;
56
57        // Search with embeddings so we can pass them to the reranker.
58        let (mut candidates, mut embeddings) = self
59            .vector_db
60            .search_with_embeddings(
61                query_embedding.clone(),
62                &request.query,
63                oversample_limit,
64                threshold_used,
65                request.project.clone(),
66                request.path.clone(),
67                request.hybrid,
68            )
69            .await
70            .context("Failed to search with embeddings")?;
71
72        // Adaptive threshold lowering if no results.
73        if candidates.is_empty() && original_threshold > 0.3 {
74            let fallback_thresholds = [0.6, 0.5, 0.4, 0.3];
75            for &threshold in &fallback_thresholds {
76                if threshold >= original_threshold {
77                    continue;
78                }
79                let (c, e) = self
80                    .vector_db
81                    .search_with_embeddings(
82                        query_embedding.clone(),
83                        &request.query,
84                        oversample_limit,
85                        threshold,
86                        request.project.clone(),
87                        request.path.clone(),
88                        request.hybrid,
89                    )
90                    .await
91                    .context("Failed to search with embeddings")?;
92                if !c.is_empty() {
93                    candidates = c;
94                    embeddings = e;
95                    threshold_used = threshold;
96                    threshold_lowered = true;
97                    break;
98                }
99            }
100        }
101
102        let has_enough = candidates.len() > final_k && embeddings.iter().all(|e| !e.is_empty());
103
104        let results = if has_enough {
105            match reranker {
106                None | Some(RerankerKind::Spectral(_)) => {
107                    let spectral_cfg = match reranker {
108                        Some(RerankerKind::Spectral(cfg)) => cfg,
109                        _ => crate::spectral::SpectralSelectConfig::default(),
110                    };
111                    if candidates.len() >= spectral_cfg.min_candidates {
112                        let r = SpectralReranker::new(spectral_cfg);
113                        let indices = r.rerank(&candidates, &embeddings, final_k);
114                        indices.into_iter().map(|i| candidates[i].clone()).collect()
115                    } else {
116                        candidates.truncate(final_k);
117                        candidates
118                    }
119                }
120                Some(RerankerKind::CrossEncoder(mut ce_cfg)) => {
121                    // Inject query embedding if caller left it empty.
122                    if ce_cfg.query_embedding.is_empty() {
123                        ce_cfg.query_embedding = query_embedding.clone();
124                    }
125                    let r = CrossEncoderReranker::new(ce_cfg);
126                    let indices = r.rerank(&candidates, &embeddings, final_k);
127                    indices.into_iter().map(|i| candidates[i].clone()).collect()
128                }
129                Some(RerankerKind::Both {
130                    spectral,
131                    mut cross_encoder,
132                }) => {
133                    // Pass 1: spectral diversity selection.
134                    let spectral_k = spectral.k.unwrap_or(final_k * 2).max(final_k);
135                    let indices1 = if candidates.len() >= spectral.min_candidates {
136                        let r = SpectralReranker::new(spectral);
137                        r.rerank(&candidates, &embeddings, spectral_k)
138                    } else {
139                        (0..candidates.len().min(spectral_k)).collect()
140                    };
141
142                    // Build intermediate candidate/embedding slices.
143                    let mid_candidates: Vec<_> =
144                        indices1.iter().map(|&i| candidates[i].clone()).collect();
145                    let mid_embeddings: Vec<_> =
146                        indices1.iter().map(|&i| embeddings[i].clone()).collect();
147
148                    // Pass 2: cross-encoder relevance ordering.
149                    if cross_encoder.query_embedding.is_empty() {
150                        cross_encoder.query_embedding = query_embedding.clone();
151                    }
152                    let r = CrossEncoderReranker::new(cross_encoder);
153                    let indices2 = r.rerank(&mid_candidates, &mid_embeddings, final_k);
154                    indices2
155                        .into_iter()
156                        .map(|i| mid_candidates[i].clone())
157                        .collect()
158                }
159            }
160        } else {
161            candidates.truncate(final_k);
162            candidates
163        };
164
165        Ok(QueryResponse {
166            results,
167            duration_ms: start.elapsed().as_millis() as u64,
168            threshold_used,
169            threshold_lowered,
170        })
171    }
172}