1use std::collections::{HashMap, HashSet};
17use std::path::Path;
18
19use super::bm25_index::ChunkKind;
20use super::hybrid_search::HybridResult;
21
/// Multiple of the highest current `rrf_score` that `definition_boost` adds to
/// every chunk it considers a definition of the queried symbol.
const DEFINITION_BOOST_MULTIPLIER: f64 = 3.0;
/// Fraction of the highest chunk score used by `file_coherence_boost` as the
/// full per-file bonus; each file's bonus is scaled by its summed score
/// relative to the best file's summed score.
const FILE_COHERENCE_FRAC: f64 = 0.2;
/// Score multiplier used by `apply_diversity`: a candidate whose file already
/// has `count >= SATURATION_THRESHOLD` selected results is scaled by
/// `SATURATION_DECAY^(count - SATURATION_THRESHOLD + 1)`.
const SATURATION_DECAY: f64 = 0.5;
/// Number of already-selected results from one file at which `apply_diversity`
/// starts decaying further candidates from that file.
const SATURATION_THRESHOLD: usize = 1;

// Multiplicative factors applied by `path_penalty`; when several categories
// match a path, their factors are multiplied together.
/// Applied to test files, compat/legacy directories and example/docs directories.
const STRONG_PENALTY: f64 = 0.3;
/// Applied to re-export "barrel" files (`__init__.py`, `package-info.java`, `index.ts`).
const MODERATE_PENALTY: f64 = 0.5;
/// Applied to type-stub files (`.d.ts`, `.pyi`).
const MILD_PENALTY: f64 = 0.7;
32
/// Coarse category of a search query, used by `resolve_weights` to pick
/// score-fusion weights and by the reranker to decide which boosts apply.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// A single identifier-like token, e.g. `AuthService`, `Foo::bar`, `get_user_by_id`.
    Symbol,
    /// Free-form text that matched neither the symbol nor the architecture rules.
    NaturalLanguage,
    /// A structural question, e.g. "how does auth work" or "where is the data flow".
    Architecture,
}
41
42pub fn classify_query(query: &str) -> QueryType {
48 let trimmed = query.trim();
49 if trimmed.is_empty() {
50 return QueryType::NaturalLanguage;
51 }
52
53 if is_symbol_query(trimmed) {
54 return QueryType::Symbol;
55 }
56
57 let lower = trimmed.to_lowercase();
58 if is_architecture_query(&lower) {
59 return QueryType::Architecture;
60 }
61
62 QueryType::NaturalLanguage
63}
64
/// Heuristically decides whether `query` is a single code identifier.
///
/// The query must consist of exactly one whitespace-delimited token. That
/// token counts as a symbol when it:
/// - contains `::` or `->`, or contains `.` together with an uppercase letter
///   (e.g. `Foo::bar`, `ptr->field`, `os.Path`);
/// - starts with `_` and is longer than one byte (e.g. `_private`);
/// - is longer than two bytes, contains `_`, and consists only of
///   alphanumerics and `_` (covers both `snake_case` and `SCREAMING_CASE`);
/// - has an ASCII lowercase→uppercase transition (e.g. `getUserById`); or
/// - starts with an uppercase letter and contains a lowercase letter later
///   (e.g. `AuthService`).
fn is_symbol_query(query: &str) -> bool {
    let mut tokens = query.split_whitespace();
    let token = match (tokens.next(), tokens.next()) {
        (Some(only), None) => only,
        _ => return false,
    };

    if token.contains("::")
        || (token.contains('.') && token.chars().any(char::is_uppercase))
        || token.contains("->")
    {
        return true;
    }

    if token.starts_with('_') && token.len() > 1 {
        return true;
    }

    // snake_case and SCREAMING_SNAKE_CASE. A dedicated all-caps rule is not
    // needed: uppercase letters and digits are already alphanumeric.
    if token.len() > 2
        && token.contains('_')
        && token.chars().all(|c| c.is_alphanumeric() || c == '_')
    {
        return true;
    }

    let has_lower_to_upper = token
        .as_bytes()
        .windows(2)
        .any(|w| w[0].is_ascii_lowercase() && w[1].is_ascii_uppercase());
    if has_lower_to_upper {
        return true;
    }

    // Walk chars instead of byte-slicing `token[1..]`: the first character may
    // be a multi-byte uppercase letter (e.g. `Émile`), and slicing at byte 1
    // would panic on a non-char boundary.
    let mut chars = token.chars();
    matches!(chars.next(), Some(first) if first.is_uppercase()) && chars.any(char::is_lowercase)
}
113
/// Returns `true` when the (already lowercased) query mentions an
/// architecture-level phrase such as "how does", "data flow", or "pipeline".
///
/// A keyword only counts when it begins at a word boundary: the start of the
/// text, or right after a non-alphanumeric character. This prevents
/// accidental hits inside other words. For example, "show issues" contains
/// "how is", "multiplayer" contains "layer", and "nowhere is" contains
/// "where is". Suffixed forms such as "layers" or "components" still match.
///
/// Keywords are lowercase, so `lower` must already be lowercased.
fn is_architecture_query(lower: &str) -> bool {
    const ARCH_KEYWORDS: &[&str] = &[
        "how does",
        "how is",
        "where is",
        "where are",
        "architecture",
        "design pattern",
        "data flow",
        "control flow",
        "module structure",
        "component",
        "layer",
        "pipeline",
    ];

    ARCH_KEYWORDS.iter().any(|&kw| {
        lower.char_indices().any(|(idx, _)| {
            let at_word_start = !matches!(
                lower[..idx].chars().next_back(),
                Some(prev) if prev.is_alphanumeric()
            );
            at_word_start && lower[idx..].starts_with(kw)
        })
    })
}
131
132pub fn resolve_weights(query_type: QueryType) -> (f64, f64) {
135 match query_type {
136 QueryType::Symbol => (1.4, 0.6),
137 QueryType::NaturalLanguage => (1.0, 1.0),
138 QueryType::Architecture => (0.6, 1.4),
139 }
140}
141
142pub fn rerank_pipeline(results: &mut Vec<HybridResult>, query: &str, top_k: usize) {
149 if results.is_empty() {
150 return;
151 }
152
153 let query_type = classify_query(query);
154
155 definition_boost(results, query, query_type);
156 file_coherence_boost(results);
157 apply_noise_penalties(results);
158 *results = apply_diversity(std::mem::take(results), top_k);
159}
160
161fn definition_boost(results: &mut [HybridResult], query: &str, query_type: QueryType) {
164 if query_type != QueryType::Symbol {
165 return;
166 }
167
168 let symbol = extract_symbol_name(query);
169 if symbol.is_empty() {
170 return;
171 }
172
173 let max_score = results.iter().map(|r| r.rrf_score).fold(0.0_f64, f64::max);
174 if max_score == 0.0 {
175 return;
176 }
177
178 let boost = max_score * DEFINITION_BOOST_MULTIPLIER;
179 let symbol_lower = symbol.to_lowercase();
180
181 for result in results.iter_mut() {
182 if is_defining_chunk(result, &symbol_lower) {
183 result.rrf_score += boost;
184 }
185 }
186}
187
/// Extracts the trailing identifier from a possibly qualified symbol query.
///
/// The query is trimmed, then everything up to and including the right-most
/// separator (`::`, `.`, or `->`) is dropped. For example, `Foo::bar` → `bar`
/// and `self.ptr->field` → `field`. A query without separators is returned
/// trimmed but otherwise unchanged; a trailing separator yields `""`.
fn extract_symbol_name(query: &str) -> &str {
    const SEPARATORS: [&str; 3] = ["::", ".", "->"];

    let trimmed = query.trim();
    // Use whichever separator occurs furthest right, regardless of its kind.
    // Checking the kinds in a fixed order returned `ptr->field` for
    // `self.ptr->field` and `bar.baz` for `Foo::bar.baz`.
    let start = SEPARATORS
        .iter()
        .filter_map(|&sep| trimmed.rfind(sep).map(|pos| pos + sep.len()))
        .max();

    start.map_or(trimmed, |idx| &trimmed[idx..])
}
204
205fn is_defining_chunk(result: &HybridResult, symbol_lower: &str) -> bool {
206 match result.kind {
207 ChunkKind::Other => false,
208 _ => result.symbol_name.to_lowercase().contains(symbol_lower),
209 }
210}
211
212fn file_coherence_boost(results: &mut [HybridResult]) {
215 if results.len() < 2 {
216 return;
217 }
218
219 let max_score = results.iter().map(|r| r.rrf_score).fold(0.0_f64, f64::max);
220 if max_score == 0.0 {
221 return;
222 }
223
224 let mut file_scores: HashMap<String, f64> = HashMap::new();
225 for r in results.iter() {
226 *file_scores.entry(r.file_path.clone()).or_insert(0.0) += r.rrf_score;
227 }
228
229 let max_file_score = file_scores.values().copied().fold(0.0_f64, f64::max);
230 if max_file_score == 0.0 {
231 return;
232 }
233
234 let boost_unit = max_score * FILE_COHERENCE_FRAC;
235 let mut seen: HashSet<String> = HashSet::new();
236
237 for result in results.iter_mut() {
238 if seen.insert(result.file_path.clone()) {
239 let file_score = file_scores.get(&result.file_path).copied().unwrap_or(0.0);
240 result.rrf_score += boost_unit * file_score / max_file_score;
241 }
242 }
243}
244
245fn apply_noise_penalties(results: &mut [HybridResult]) {
248 for result in results.iter_mut() {
249 let penalty = path_penalty(&result.file_path);
250 if penalty < 1.0 {
251 result.rrf_score *= penalty;
252 }
253 }
254}
255
256fn path_penalty(file_path: &str) -> f64 {
257 let normalized = file_path.replace('\\', "/");
258 let mut penalty = 1.0;
259
260 if is_test_file(&normalized) {
261 penalty *= STRONG_PENALTY;
262 }
263 if is_compat_legacy(&normalized) {
264 penalty *= STRONG_PENALTY;
265 }
266 if is_example_docs(&normalized) {
267 penalty *= STRONG_PENALTY;
268 }
269 if is_reexport_barrel(&normalized) {
270 penalty *= MODERATE_PENALTY;
271 }
272 if is_type_stub(&normalized) {
273 penalty *= MILD_PENALTY;
274 }
275
276 penalty
277}
278
/// Returns `true` when `path` looks like test code.
///
/// A path counts as a test when either of these holds:
/// - its file name follows a common test naming convention: `test_*`,
///   `*_test.*`, `*.test.*`, `*.spec.*`, `*Test.java`, `*Tests.java`,
///   `*Test.kt`, `*Test.cs`, `*Tests.swift`, or `*_spec.rb`;
/// - any directory component is exactly `tests`, `test`, `__tests__`,
///   `spec`, or `testing`.
///
/// Expects `/`-separated paths; `path_penalty` converts backslashes before
/// calling.
fn is_test_file(path: &str) -> bool {
    const TEST_DIRS: [&str; 5] = ["tests", "test", "__tests__", "spec", "testing"];

    let filename = Path::new(path)
        .file_name()
        .and_then(|f| f.to_str())
        .unwrap_or("");

    if filename.starts_with("test_") || filename.contains("_test.") {
        return true;
    }
    if filename.contains(".test.") || filename.contains(".spec.") {
        return true;
    }
    if filename.ends_with("Test.java")
        || filename.ends_with("Tests.java")
        || filename.ends_with("Test.kt")
        || filename.ends_with("Test.cs")
        || filename.ends_with("Tests.swift")
    {
        return true;
    }
    if filename.ends_with("_spec.rb") {
        return true;
    }

    // Compare whole directory components (the last segment is the file
    // itself). Matching `"/spec/"` substrings missed top-level directories
    // such as `spec/` or `__tests__/`, which have no leading slash.
    path.split('/')
        .rev()
        .skip(1)
        .any(|dir| TEST_DIRS.contains(&dir))
}
316
/// Returns `true` when any directory component of `path` is exactly
/// `compat`, `_compat`, `legacy`, or `deprecated`.
///
/// Whole components are compared, so a top-level directory (`legacy/old.py`)
/// matches just like a nested one (`src/compat/old.rs`). Lookalikes such as
/// `compatibility/` or a file named `legacy.rs` do not match. Expects
/// `/`-separated paths.
fn is_compat_legacy(path: &str) -> bool {
    path.split('/')
        .rev()
        .skip(1) // last segment is the file name, not a directory
        .any(|dir| matches!(dir, "compat" | "_compat" | "legacy" | "deprecated"))
}
323
/// Returns `true` when any directory component of `path` is exactly
/// `examples`, `example`, `_examples`, or `docs_src`.
///
/// Whole components are compared, so top-level directories (`docs_src/…`,
/// `_examples/…`) match just like nested ones. A file named `example.rs`
/// does not match. Expects `/`-separated paths.
fn is_example_docs(path: &str) -> bool {
    path.split('/')
        .rev()
        .skip(1) // last segment is the file name, not a directory
        .any(|dir| matches!(dir, "examples" | "example" | "_examples" | "docs_src"))
}
332
/// Returns `true` when the file name is one this module treats as a
/// re-export/package "barrel": `__init__.py`, `package-info.java`, or `index.ts`.
///
/// Only the final path component is compared, exactly and case-sensitively.
fn is_reexport_barrel(path: &str) -> bool {
    let name = Path::new(path).file_name().and_then(|n| n.to_str());
    matches!(name, Some("__init__.py" | "package-info.java" | "index.ts"))
}
340
// Suffixes are compared against an ASCII-lowercased copy of the path, so the
// case-sensitive extension comparison this lint warns about is intentional.
#[allow(clippy::case_sensitive_file_extension_comparisons)]
/// Returns `true` for TypeScript declaration files (`.d.ts`) and Python stub
/// files (`.pyi`), ignoring ASCII case.
fn is_type_stub(path: &str) -> bool {
    const STUB_SUFFIXES: [&str; 2] = [".d.ts", ".pyi"];
    let folded = path.to_ascii_lowercase();
    STUB_SUFFIXES.iter().any(|&suffix| folded.ends_with(suffix))
}
346
347fn apply_diversity(mut results: Vec<HybridResult>, top_k: usize) -> Vec<HybridResult> {
350 if results.is_empty() {
351 return results;
352 }
353
354 results.sort_by(|a, b| {
355 b.rrf_score
356 .partial_cmp(&a.rrf_score)
357 .unwrap_or(std::cmp::Ordering::Equal)
358 });
359
360 let mut selected: Vec<HybridResult> = Vec::with_capacity(top_k);
361 let mut file_count: HashMap<&str, usize> = HashMap::new();
362 let mut remaining: Vec<(usize, f64)> = results
363 .iter()
364 .enumerate()
365 .map(|(i, r)| (i, r.rrf_score))
366 .collect();
367
368 while selected.len() < top_k && !remaining.is_empty() {
369 let mut best_idx = 0;
371 let mut best_effective = f64::NEG_INFINITY;
372
373 for (pos, &(orig_idx, base_score)) in remaining.iter().enumerate() {
374 let file = results[orig_idx].file_path.as_str();
375 let count = file_count.get(file).copied().unwrap_or(0);
376 let effective = if count >= SATURATION_THRESHOLD {
377 let excess = count - SATURATION_THRESHOLD + 1;
378 base_score * SATURATION_DECAY.powi(excess as i32)
379 } else {
380 base_score
381 };
382
383 if effective > best_effective {
384 best_effective = effective;
385 best_idx = pos;
386 }
387 }
388
389 let (orig_idx, _) = remaining.remove(best_idx);
390 let file = results[orig_idx].file_path.as_str();
391 *file_count.entry(file).or_insert(0) += 1;
392 selected.push(results[orig_idx].clone());
393 }
394
395 selected
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HybridResult` fixture covering lines 1-10 whose BM25 score
    /// equals its fused score, at BM25 rank 1 and with no dense hit.
    fn make_result(file: &str, symbol: &str, kind: ChunkKind, score: f64) -> HybridResult {
        HybridResult {
            file_path: file.to_string(),
            symbol_name: symbol.to_string(),
            kind,
            start_line: 1,
            end_line: 10,
            snippet: String::new(),
            rrf_score: score,
            bm25_score: Some(score),
            dense_score: None,
            bm25_rank: Some(1),
            dense_rank: None,
        }
    }

    /// Asserts that every query in `queries` is classified as `expected`.
    fn assert_all_classify(queries: &[&str], expected: QueryType) {
        for &query in queries {
            assert_eq!(classify_query(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn classify_symbol_queries() {
        assert_all_classify(
            &[
                "AuthService",
                "Foo::bar",
                "get_user_by_id",
                "_private",
                "HTTP_CLIENT",
                "getUserById",
            ],
            QueryType::Symbol,
        );
    }

    #[test]
    fn classify_nl_queries() {
        assert_all_classify(
            &["authentication flow", "save model to disk", "error handling"],
            QueryType::NaturalLanguage,
        );
    }

    #[test]
    fn classify_architecture_queries() {
        assert_all_classify(
            &[
                "how does auth work",
                "where is the data flow",
                "module structure overview",
            ],
            QueryType::Architecture,
        );
    }

    #[test]
    fn definition_boost_works() {
        let mut results = vec![
            make_result("src/auth.rs", "authenticate", ChunkKind::Function, 0.5),
            make_result("src/main.rs", "main", ChunkKind::Function, 0.8),
            make_result("src/auth.rs", "AuthService", ChunkKind::Struct, 0.4),
        ];

        definition_boost(&mut results, "AuthService", QueryType::Symbol);

        let (main_score, definition_score) = (results[1].rrf_score, results[2].rrf_score);
        assert!(definition_score > main_score);
    }

    #[test]
    fn noise_penalty_applies() {
        let mut results = vec![
            make_result("src/auth.rs", "auth", ChunkKind::Function, 1.0),
            make_result("tests/test_auth.rs", "test_auth", ChunkKind::Function, 1.0),
        ];

        apply_noise_penalties(&mut results);

        let (source_score, test_score) = (results[0].rrf_score, results[1].rrf_score);
        assert!(source_score > test_score);
        assert!((test_score - STRONG_PENALTY).abs() < 0.001);
    }

    #[test]
    fn file_coherence_boosts_multi_chunk_files() {
        let mut results = vec![
            make_result("src/auth.rs", "login", ChunkKind::Function, 0.5),
            make_result("src/auth.rs", "logout", ChunkKind::Function, 0.4),
            make_result("src/main.rs", "main", ChunkKind::Function, 0.6),
        ];

        file_coherence_boost(&mut results);

        assert!(results[0].rrf_score > 0.5);
    }

    #[test]
    fn diversity_limits_same_file() {
        let results = vec![
            make_result("src/big.rs", "fn1", ChunkKind::Function, 1.0),
            make_result("src/big.rs", "fn2", ChunkKind::Function, 0.9),
            make_result("src/big.rs", "fn3", ChunkKind::Function, 0.8),
            make_result("src/other.rs", "fn4", ChunkKind::Function, 0.7),
        ];

        let diverse = apply_diversity(results, 3);

        assert!(diverse.iter().any(|r| r.file_path == "src/other.rs"));
    }

    #[test]
    fn extract_symbol_from_qualified() {
        let cases = [
            ("Foo::bar", "bar"),
            ("obj.method", "method"),
            ("ptr->field", "field"),
            ("SimpleIdent", "SimpleIdent"),
        ];
        for &(query, expected) in &cases {
            assert_eq!(extract_symbol_name(query), expected, "query: {:?}", query);
        }
    }

    #[test]
    fn path_penalties_correct() {
        let cases = [
            ("src/auth.rs", 1.0),
            ("tests/test_auth.py", STRONG_PENALTY),
            ("src/compat/old.rs", STRONG_PENALTY),
            ("src/types.d.ts", MILD_PENALTY),
            ("src/__init__.py", MODERATE_PENALTY),
        ];
        for &(path, expected) in &cases {
            assert!(
                (path_penalty(path) - expected).abs() < 0.001,
                "path: {:?}",
                path
            );
        }
    }
}
531}