brainwires_rag/rag/client/
reranking.rs1use super::RagClient;
6use crate::rag::types::*;
7use anyhow::{Context, Result};
8use std::time::Instant;
9
10impl RagClient {
11 pub async fn query_diverse(
25 &self,
26 request: QueryRequest,
27 reranker: Option<crate::spectral::RerankerKind>,
28 ) -> Result<QueryResponse> {
29 use crate::spectral::{
30 CrossEncoderReranker, DiversityReranker, RerankerKind, SpectralReranker,
31 };
32
33 request.validate().map_err(|e| anyhow::anyhow!(e))?;
34 self.check_path_not_dirty(request.path.as_deref()).await?;
35
36 let start = Instant::now();
37
38 let final_k = match &reranker {
40 Some(RerankerKind::Spectral(cfg)) => cfg.k.unwrap_or(request.limit),
41 Some(RerankerKind::Both { spectral, .. }) => spectral.k.unwrap_or(request.limit),
42 _ => request.limit,
43 };
44
45 let oversample_limit = final_k * 3;
47
48 let query_embedding = self
49 .embedding_provider
50 .embed(&request.query)
51 .context("Failed to generate query embedding")?;
52
53 let original_threshold = request.min_score;
54 let mut threshold_used = original_threshold;
55 let mut threshold_lowered = false;
56
57 let (mut candidates, mut embeddings) = self
59 .vector_db
60 .search_with_embeddings(
61 query_embedding.clone(),
62 &request.query,
63 oversample_limit,
64 threshold_used,
65 request.project.clone(),
66 request.path.clone(),
67 request.hybrid,
68 )
69 .await
70 .context("Failed to search with embeddings")?;
71
72 if candidates.is_empty() && original_threshold > 0.3 {
74 let fallback_thresholds = [0.6, 0.5, 0.4, 0.3];
75 for &threshold in &fallback_thresholds {
76 if threshold >= original_threshold {
77 continue;
78 }
79 let (c, e) = self
80 .vector_db
81 .search_with_embeddings(
82 query_embedding.clone(),
83 &request.query,
84 oversample_limit,
85 threshold,
86 request.project.clone(),
87 request.path.clone(),
88 request.hybrid,
89 )
90 .await
91 .context("Failed to search with embeddings")?;
92 if !c.is_empty() {
93 candidates = c;
94 embeddings = e;
95 threshold_used = threshold;
96 threshold_lowered = true;
97 break;
98 }
99 }
100 }
101
102 let has_enough = candidates.len() > final_k && embeddings.iter().all(|e| !e.is_empty());
103
104 let results = if has_enough {
105 match reranker {
106 None | Some(RerankerKind::Spectral(_)) => {
107 let spectral_cfg = match reranker {
108 Some(RerankerKind::Spectral(cfg)) => cfg,
109 _ => crate::spectral::SpectralSelectConfig::default(),
110 };
111 if candidates.len() >= spectral_cfg.min_candidates {
112 let r = SpectralReranker::new(spectral_cfg);
113 let indices = r.rerank(&candidates, &embeddings, final_k);
114 indices.into_iter().map(|i| candidates[i].clone()).collect()
115 } else {
116 candidates.truncate(final_k);
117 candidates
118 }
119 }
120 Some(RerankerKind::CrossEncoder(mut ce_cfg)) => {
121 if ce_cfg.query_embedding.is_empty() {
123 ce_cfg.query_embedding = query_embedding.clone();
124 }
125 let r = CrossEncoderReranker::new(ce_cfg);
126 let indices = r.rerank(&candidates, &embeddings, final_k);
127 indices.into_iter().map(|i| candidates[i].clone()).collect()
128 }
129 Some(RerankerKind::Both {
130 spectral,
131 mut cross_encoder,
132 }) => {
133 let spectral_k = spectral.k.unwrap_or(final_k * 2).max(final_k);
135 let indices1 = if candidates.len() >= spectral.min_candidates {
136 let r = SpectralReranker::new(spectral);
137 r.rerank(&candidates, &embeddings, spectral_k)
138 } else {
139 (0..candidates.len().min(spectral_k)).collect()
140 };
141
142 let mid_candidates: Vec<_> =
144 indices1.iter().map(|&i| candidates[i].clone()).collect();
145 let mid_embeddings: Vec<_> =
146 indices1.iter().map(|&i| embeddings[i].clone()).collect();
147
148 if cross_encoder.query_embedding.is_empty() {
150 cross_encoder.query_embedding = query_embedding.clone();
151 }
152 let r = CrossEncoderReranker::new(cross_encoder);
153 let indices2 = r.rerank(&mid_candidates, &mid_embeddings, final_k);
154 indices2
155 .into_iter()
156 .map(|i| mid_candidates[i].clone())
157 .collect()
158 }
159 }
160 } else {
161 candidates.truncate(final_k);
162 candidates
163 };
164
165 Ok(QueryResponse {
166 results,
167 duration_ms: start.elapsed().as_millis() as u64,
168 threshold_used,
169 threshold_lowered,
170 })
171 }
172}