1use std::collections::{HashMap, HashSet};
7
8use crate::core::bm25_index::{tokenize_for_index, BM25Index};
9
/// One ranked hit produced by `hybrid_retrieve`.
///
/// Carries both raw component scores alongside the blended score used for ranking, so callers
/// can inspect why a chunk was placed where it was.
#[derive(Clone, Debug, PartialEq)]
pub struct SpladeResult {
    /// Position of the chunk in the index's `chunks` vector.
    pub chunk_idx: usize,
    /// File path copied from the matched chunk.
    pub file_path: String,
    /// Symbol name copied from the matched chunk.
    pub symbol_name: String,
    /// Weighted blend of the max-normalised BM25 and expansion scores; the ranking key.
    pub combined_score: f64,
    /// Raw (un-normalised) stage-1 BM25 score for the original query.
    pub bm25_score: f64,
    /// Raw (un-normalised) BM25 score accumulated over expansion-only terms.
    pub expansion_score: f64,
}
20
/// Hand-curated synonym table used for query expansion.
///
/// Keys are lowercase query terms. Each key maps to its related terms, and every related term
/// carries a weight in `(0.0, 1.0]`: a higher weight means a closer association. Within each
/// list, entries keep the order written below.
fn expansion_dictionary() -> HashMap<&'static str, Vec<(&'static str, f64)>> {
    const TABLE: &[(&str, &[(&str, f64)])] = &[
        (
            "auth",
            &[
                ("authentication", 1.0),
                ("token", 0.9),
                ("jwt", 0.85),
                ("login", 0.8),
                ("session", 0.85),
                ("oauth", 0.75),
                ("credential", 0.7),
            ],
        ),
        (
            "async",
            &[
                ("await", 1.0),
                ("future", 0.85),
                ("promise", 0.75),
                ("tokio", 0.65),
                ("concurrent", 0.7),
            ],
        ),
        (
            "error",
            &[
                ("err", 0.95),
                ("result", 0.75),
                ("panic", 0.55),
                ("exception", 0.65),
                ("unwrap", 0.5),
            ],
        ),
        (
            "http",
            &[
                ("request", 0.85),
                ("response", 0.85),
                ("rest", 0.65),
                ("json", 0.7),
                ("header", 0.6),
            ],
        ),
        (
            "db",
            &[
                ("database", 1.0),
                ("sql", 0.85),
                ("query", 0.75),
                ("transaction", 0.65),
                ("migration", 0.55),
            ],
        ),
        (
            "test",
            &[
                ("mock", 0.75),
                ("fixture", 0.6),
                ("assert", 0.85),
                ("expect", 0.65),
            ],
        ),
        (
            "config",
            &[
                ("configuration", 1.0),
                ("env", 0.75),
                ("setting", 0.65),
                ("toml", 0.55),
            ],
        ),
        (
            "cache",
            &[
                ("memo", 0.65),
                ("redis", 0.55),
                ("ttl", 0.6),
                ("invalidate", 0.55),
            ],
        ),
    ];

    TABLE
        .iter()
        .map(|&(term, related)| (term, related.to_vec()))
        .collect()
}
105
106fn build_expanded_weights(query_tokens: &[String]) -> HashMap<String, f64> {
107 let dict = expansion_dictionary();
108 let mut out: HashMap<String, f64> = HashMap::new();
109
110 for t in query_tokens {
111 let lower = t.to_lowercase();
112 let entry = out.entry(lower.clone()).or_insert(1.0);
113 *entry = (*entry).max(1.0);
114
115 if let Some(rel) = dict.get(lower.as_str()) {
116 for (syn, w) in rel {
117 let le = out.entry((*syn).to_string()).or_insert(0.0);
118 *le = (*le).max(*w);
119 }
120 }
121 }
122 out
123}
124
125fn expansion_bm25_for_chunk(
126 index: &BM25Index,
127 chunk_idx: usize,
128 expanded: &HashMap<String, f64>,
129 original_terms: &HashSet<String>,
130) -> f64 {
131 if index.doc_count == 0 {
132 return 0.0;
133 }
134
135 let doc_len = index.chunks[chunk_idx].token_count as f64;
136 let norm_len = doc_len / index.avg_doc_len.max(1.0);
137
138 const K1: f64 = 1.2;
139 const B: f64 = 0.75;
140
141 let mut sum = 0.0;
142 for (term, ew) in expanded {
143 if original_terms.contains(term) {
144 continue;
145 }
146 let df = *index.doc_freqs.get(term).unwrap_or(&0) as f64;
147 if df == 0.0 {
148 continue;
149 }
150
151 let idf = ((index.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
152 let tf = index.inverted.get(term).map_or(0.0, |postings| {
153 postings
154 .iter()
155 .filter(|(idx, _)| *idx == chunk_idx)
156 .map(|(_, w)| *w)
157 .sum::<f64>()
158 });
159
160 if tf == 0.0 {
161 continue;
162 }
163
164 let bm25_t = idf * (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * norm_len));
165 sum += ew * bm25_t;
166 }
167 sum
168}
169
170pub fn hybrid_retrieve(query: &str, bm25_index: &BM25Index, top_k: usize) -> Vec<SpladeResult> {
172 if bm25_index.doc_count == 0 || top_k == 0 {
173 return Vec::new();
174 }
175
176 let query_tokens = tokenize_for_index(query);
177 if query_tokens.is_empty() {
178 return Vec::new();
179 }
180
181 let original_terms: HashSet<String> = query_tokens.iter().map(|s| s.to_lowercase()).collect();
182
183 let expanded = build_expanded_weights(&query_tokens);
184
185 let stage1 = bm25_index.search(query, 100.min(bm25_index.doc_count.max(1)));
186
187 let bm25_by_chunk: HashMap<usize, f64> =
188 stage1.iter().map(|r| (r.chunk_idx, r.score)).collect();
189
190 let chunk_indices: Vec<usize> = bm25_by_chunk.keys().copied().collect();
191 if chunk_indices.is_empty() {
192 return Vec::new();
193 }
194
195 let mut expansion_scores: HashMap<usize, f64> = HashMap::new();
196 for &idx in &chunk_indices {
197 let es = expansion_bm25_for_chunk(bm25_index, idx, &expanded, &original_terms);
198 expansion_scores.insert(idx, es);
199 }
200
201 let max_bm25 = bm25_by_chunk.values().copied().fold(0.0_f64, f64::max);
202 let max_exp = expansion_scores.values().copied().fold(0.0_f64, f64::max);
203
204 let norm_bm25 = |s: f64| -> f64 {
205 if max_bm25 > 1e-12 {
206 s / max_bm25
207 } else {
208 0.0
209 }
210 };
211 let norm_exp = |s: f64| -> f64 {
212 if max_exp > 1e-12 {
213 s / max_exp
214 } else {
215 0.0
216 }
217 };
218
219 const W_BM25: f64 = 0.65;
220 const W_EXP: f64 = 0.35;
221
222 let mut results: Vec<SpladeResult> = chunk_indices
223 .into_iter()
224 .map(|chunk_idx| {
225 let chunk = &bm25_index.chunks[chunk_idx];
226 let bm25_score = bm25_by_chunk.get(&chunk_idx).copied().unwrap_or(0.0);
227 let expansion_score = expansion_scores.get(&chunk_idx).copied().unwrap_or(0.0);
228 let combined_score = W_BM25 * norm_bm25(bm25_score) + W_EXP * norm_exp(expansion_score);
229
230 SpladeResult {
231 chunk_idx,
232 file_path: chunk.file_path.clone(),
233 symbol_name: chunk.symbol_name.clone(),
234 combined_score,
235 bm25_score,
236 expansion_score,
237 }
238 })
239 .collect();
240
241 results.sort_by(|a, b| {
242 b.combined_score
243 .partial_cmp(&a.combined_score)
244 .unwrap_or(std::cmp::Ordering::Equal)
245 .then_with(|| {
246 b.bm25_score
247 .partial_cmp(&a.bm25_score)
248 .unwrap_or(std::cmp::Ordering::Equal)
249 })
250 .then_with(|| a.file_path.cmp(&b.file_path))
251 .then_with(|| a.symbol_name.cmp(&b.symbol_name))
252 });
253
254 results.truncate(top_k);
255 results
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::bm25_index::{ChunkKind, CodeChunk};

    /// Three-chunk fixture: two authentication-flavoured chunks and one caching chunk.
    fn sample_index() -> BM25Index {
        BM25Index::from_chunks_for_test(vec![
            CodeChunk {
                file_path: "login.rs".into(),
                symbol_name: "login_user".into(),
                kind: ChunkKind::Function,
                start_line: 1,
                end_line: 3,
                content: "pub fn login_user() { session_token authentication jwt oauth login credential }"
                    .into(),
                tokens: vec![],
                token_count: 0,
            },
            CodeChunk {
                file_path: "cache.rs".into(),
                symbol_name: "memo_cache".into(),
                kind: ChunkKind::Function,
                start_line: 1,
                end_line: 3,
                content: "pub fn memo_cache() { redis ttl invalidate memo cache }".into(),
                tokens: vec![],
                token_count: 0,
            },
            CodeChunk {
                file_path: "auth.rs".into(),
                symbol_name: "oauth_flow".into(),
                kind: ChunkKind::Function,
                start_line: 1,
                end_line: 3,
                content: "pub fn oauth_flow() { credential authentication token jwt session }".into(),
                tokens: vec![],
                token_count: 0,
            },
        ])
    }

    #[test]
    fn hybrid_prefers_expansion_overlap() {
        let index = sample_index();

        let hits = hybrid_retrieve("auth login", &index, 5);
        assert!(!hits.is_empty());
        let top_path = hits[0].file_path.clone();
        assert!(
            top_path.ends_with("login.rs") || top_path.ends_with("auth.rs"),
            "expected auth-expanded chunk first, got {hits:?}"
        );
    }

    #[test]
    fn expansion_boosts_related_terms() {
        let index = sample_index();
        let hybrid = hybrid_retrieve("jwt auth", &index, 10);

        assert!(!hybrid.is_empty());
        assert!(
            hybrid.iter().all(|r| r.expansion_score >= 0.0),
            "expansion_score should be non-negative: {hybrid:?}"
        );
        // Both "jwt" chunks also contain expansions of "auth" (e.g. "authentication",
        // "credential"), so at least one candidate must get a positive expansion score.
        assert!(
            hybrid.iter().any(|r| r.expansion_score > 0.0),
            "expected a positive expansion score, got {hybrid:?}"
        );
    }

    #[test]
    fn empty_query_returns_empty() {
        let index = sample_index();
        assert!(hybrid_retrieve("", &index, 5).is_empty());
    }

    #[test]
    fn zero_top_k_returns_empty() {
        let index = sample_index();
        assert!(hybrid_retrieve("auth login", &index, 0).is_empty());
    }

    #[test]
    fn results_respect_top_k_and_are_sorted() {
        let index = sample_index();

        let one = hybrid_retrieve("auth session jwt", &index, 1);
        assert!(one.len() <= 1, "top_k = 1 must yield at most one hit: {one:?}");

        let all = hybrid_retrieve("auth session jwt", &index, 10);
        assert!(
            all.windows(2)
                .all(|w| w[0].combined_score >= w[1].combined_score),
            "results must be ordered by descending combined_score: {all:?}"
        );
    }

    #[test]
    fn expanded_weights_keep_query_terms_at_full_weight() {
        let weights = build_expanded_weights(&["Auth".to_string(), "token".to_string()]);

        // Query tokens are lowercased and weighted 1.0.
        assert_eq!(weights.get("auth"), Some(&1.0));
        // "token" is both a query token and an expansion of "auth"; the query weight wins.
        assert_eq!(weights.get("token"), Some(&1.0));
        // Pure expansion terms carry their dictionary weight.
        assert_eq!(weights.get("jwt"), Some(&0.85));
        // Terms from unrelated dictionary entries are not pulled in.
        assert!(!weights.contains_key("database"));
    }

    #[test]
    fn splade_result_fields_populated() {
        let index = sample_index();

        let hybrid = hybrid_retrieve("auth session", &index, 3);
        assert!(!hybrid.is_empty(), "expected at least one hit for 'auth session'");
        let r = &hybrid[0];
        assert!(r.chunk_idx < index.chunks.len());
        assert!(!r.file_path.is_empty());
        assert!(r.combined_score.is_finite());
        assert!(r.bm25_score.is_finite());
        assert!(r.expansion_score.is_finite());
    }
}