1use crate::db::{IndexDb, index_db_path};
2use crate::project::ProjectRoot;
3use anyhow::Result;
4use serde::Serialize;
5use strsim::jaro_winkler;
6
/// Semantic score an already-matched symbol must exceed (strictly) before
/// `search_symbols_hybrid_with_semantic` adds a semantic bonus to its score.
pub const SEMANTIC_BOOST_THRESHOLD: f64 = 0.10;
/// Semantic score a symbol that no lexical tier matched must exceed (strictly) before
/// `search_symbols_hybrid_with_semantic` surfaces it as a new `"semantic"` result.
pub const SEMANTIC_NEW_RESULT_THRESHOLD: f64 = 0.15;
/// NOTE(review): not referenced anywhere in this file; presumably the semantic-similarity
/// cut-off used by coupling analysis elsewhere in the crate — confirm against its callers.
pub const SEMANTIC_COUPLING_THRESHOLD: f64 = 0.12;
13
/// One ranked hit produced by the hybrid symbol search.
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
    /// Symbol name as stored in the index (e.g. `ServiceManager`).
    pub name: String,
    /// Symbol kind as stored in the index (the tests use `"class"` and `"function"`).
    pub kind: String,
    /// Path of the file that declares the symbol, as stored in the index.
    pub file: String,
    /// Line recorded for the symbol in the index, converted from the stored `i64`.
    pub line: usize,
    /// Declaration text recorded for the symbol (e.g. `def helper():`).
    pub signature: String,
    /// Name path recorded for the symbol in the index.
    /// NOTE(review): presumably the qualified path through enclosing symbols — confirm.
    pub name_path: String,
    /// Relevance score; results are returned in descending score order before the
    /// per-file diversity cap is applied.
    pub score: f64,
    /// Tier that produced the hit: `"exact"`, `"fts"`, `"fuzzy"` or `"semantic"`.
    pub match_type: String,
}
25
/// Score bonus passed to `apply_pagerank_boost` by the hybrid search: the file with the
/// highest PageRank value gains this many points, other ranked files proportionally less.
pub const PAGERANK_MAX_BOOST: f64 = 5.0;
30
/// Searches the project's symbol index using the lexical tiers only.
///
/// Convenience wrapper around [`search_symbols_hybrid_with_semantic`] that passes `None`
/// for both the semantic scores and the PageRank scores, so no semantic results or
/// PageRank bonuses are applied.
///
/// # Errors
///
/// Returns whatever error [`search_symbols_hybrid_with_semantic`] returns.
pub fn search_symbols_hybrid(
    project: &ProjectRoot,
    query: &str,
    max_results: usize,
    fuzzy_threshold: f64,
) -> Result<Vec<SearchResult>> {
    search_symbols_hybrid_with_semantic(project, query, max_results, fuzzy_threshold, None, None)
}
47
48pub fn search_symbols_hybrid_with_semantic(
55 project: &ProjectRoot,
56 query: &str,
57 max_results: usize,
58 fuzzy_threshold: f64,
59 semantic_scores: Option<&std::collections::HashMap<String, f64>>,
60 pagerank_scores: Option<&std::collections::HashMap<String, f64>>,
61) -> Result<Vec<SearchResult>> {
62 let db_path = index_db_path(project.as_path());
63 let db = IndexDb::open(&db_path)?;
64
65 let mut seen: std::collections::HashSet<(String, String, i64)> =
66 std::collections::HashSet::new();
67 let mut results: Vec<SearchResult> = Vec::new();
68
69 for (row, file) in db.find_symbols_with_path(query, true, max_results)? {
71 let key = (row.name.clone(), file.clone(), row.line);
72 if seen.insert(key) {
73 results.push(SearchResult {
74 name: row.name,
75 kind: row.kind,
76 file,
77 line: row.line as usize,
78 signature: row.signature,
79 name_path: row.name_path,
80 score: 100.0,
81 match_type: "exact".to_owned(),
82 });
83 }
84 }
85
86 for (row, file, rank) in db.search_symbols_fts(query, max_results)? {
89 let key = (row.name.clone(), file.clone(), row.line);
90 if seen.insert(key) {
91 let fts_score = (80.0 + rank.clamp(-40.0, 0.0)).max(40.0);
93 results.push(SearchResult {
94 name: row.name,
95 kind: row.kind,
96 file,
97 line: row.line as usize,
98 signature: row.signature,
99 name_path: row.name_path,
100 score: fts_score,
101 match_type: "fts".to_owned(),
102 });
103 }
104 }
105
106 let query_lower = query.to_ascii_lowercase();
108 let prefix: String = query_lower.chars().take(2).collect();
109 let fuzzy_candidates = if prefix.len() >= 2 {
110 db.find_symbols_with_path(&prefix, false, 500)?
111 } else {
112 db.find_symbols_with_path(&query_lower, false, 500)?
113 };
114 for (row, file) in fuzzy_candidates {
115 let key = (row.name.clone(), file.clone(), row.line);
116 if seen.contains(&key) {
117 continue;
118 }
119 let sim = jaro_winkler(&query_lower, &row.name.to_ascii_lowercase());
120 if sim >= fuzzy_threshold {
121 seen.insert(key);
122 results.push(SearchResult {
123 name: row.name,
124 kind: row.kind,
125 file,
126 line: row.line as usize,
127 signature: row.signature,
128 name_path: row.name_path,
129 score: sim * 100.0,
130 match_type: "fuzzy".to_owned(),
131 });
132 }
133 }
134
135 if let Some(scores) = semantic_scores {
137 let all_symbols = db.all_symbol_names()?;
142 for (name, kind, file_path, line, signature, name_path) in all_symbols {
143 let key = (name.clone(), file_path.clone(), line);
144 if seen.contains(&key) {
145 let sem_key = format!("{file_path}:{name}");
146 if let Some(&sem_score) = scores.get(&sem_key)
147 && sem_score > SEMANTIC_BOOST_THRESHOLD
148 && let Some(existing) = results
149 .iter_mut()
150 .find(|r| r.name == name && r.file == file_path && r.line == line as usize)
151 {
152 existing.score += (sem_score * 15.0).min(10.0);
153 }
154 continue;
155 }
156 let sem_key = format!("{file_path}:{name}");
157 if let Some(&sem_score) = scores
158 .get(&sem_key)
159 .filter(|&&s| s > SEMANTIC_NEW_RESULT_THRESHOLD)
160 {
161 seen.insert(key);
162 results.push(SearchResult {
163 name,
164 kind,
165 file: file_path,
166 line: line as usize,
167 signature,
168 name_path,
169 score: sem_score * 90.0,
170 match_type: "semantic".to_owned(),
171 });
172 }
173 }
174 }
175
176 if let Some(pr) = pagerank_scores {
177 apply_pagerank_boost(&mut results, pr, PAGERANK_MAX_BOOST);
178 }
179
180 results.sort_by(|a, b| {
181 b.score
182 .partial_cmp(&a.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 });
185
186 const MAX_PER_FILE: usize = 3;
189 if results.len() > max_results {
190 let mut file_counts: std::collections::HashMap<String, usize> =
191 std::collections::HashMap::new();
192 let mut promoted = Vec::with_capacity(max_results);
193 let mut demoted = Vec::new();
194 for r in results {
195 let count = file_counts.entry(r.file.clone()).or_insert(0);
196 if *count < MAX_PER_FILE {
197 *count += 1;
198 promoted.push(r);
199 } else {
200 demoted.push(r);
201 }
202 }
203 promoted.extend(demoted);
204 results = promoted;
205 }
206
207 results.truncate(max_results);
208 Ok(results)
209}
210
211fn apply_pagerank_boost(
216 results: &mut [SearchResult],
217 pagerank_scores: &std::collections::HashMap<String, f64>,
218 max_boost: f64,
219) {
220 let max_pr = pagerank_scores.values().copied().fold(0.0_f64, f64::max);
221 if max_pr <= 0.0 {
222 return;
223 }
224 for result in results.iter_mut() {
225 if let Some(&pr) = pagerank_scores.get(&result.file) {
226 result.score += (pr / max_pr) * max_boost;
227 }
228 }
229}
230
// Unit tests for the hybrid search tiers, the semantic blending and the PageRank boost.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{IndexDb, NewSymbol, index_db_path};

    /// Builds a throw-away project whose index holds one file, `main.py`, with three
    /// symbols: the class `ServiceManager` and the functions `run_service` and `helper`.
    ///
    /// The `TempDir` handle is returned together with the project so each test can hold
    /// on to it for as long as it uses the project.
    fn make_project_with_symbols() -> (tempfile::TempDir, ProjectRoot) {
        let temp = tempfile::tempdir().expect("create temp dir for search test fixture");
        let root = temp.path();

        // NOTE(review): presumably gives the directory some content so that
        // `ProjectRoot::new` accepts it — confirm.
        std::fs::write(root.join("hello.txt"), "hello").unwrap();

        // Populate the index database at the location the search code will open.
        let db_path = index_db_path(root);
        let db = IndexDb::open(&db_path).unwrap();
        let fid = db
            .upsert_file("main.py", 100, "h1", 10, Some("py"))
            .unwrap();
        db.insert_symbols(
            fid,
            &[
                NewSymbol {
                    name: "ServiceManager",
                    kind: "class",
                    line: 1,
                    column_num: 0,
                    start_byte: 0,
                    end_byte: 100,
                    signature: "class ServiceManager:",
                    name_path: "ServiceManager",
                    parent_id: None,
                },
                NewSymbol {
                    name: "run_service",
                    kind: "function",
                    line: 10,
                    column_num: 0,
                    start_byte: 101,
                    end_byte: 200,
                    signature: "def run_service():",
                    name_path: "run_service",
                    parent_id: None,
                },
                NewSymbol {
                    name: "helper",
                    kind: "function",
                    line: 20,
                    column_num: 0,
                    start_byte: 201,
                    end_byte: 300,
                    signature: "def helper():",
                    name_path: "helper",
                    parent_id: None,
                },
            ],
        )
        .unwrap();

        let project = ProjectRoot::new(root.to_str().unwrap()).unwrap();
        (temp, project)
    }

    // An exact name match is ranked first with the fixed exact-tier score of 100.0.
    #[test]
    fn exact_match_gets_highest_score() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "ServiceManager", 10, 0.6).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].name, "ServiceManager");
        assert_eq!(results[0].match_type, "exact");
        assert_eq!(results[0].score, 100.0);
    }

    // A partial query must yield at least one text-match hit. The search code in this
    // file labels that tier "fts"; "substring" is also accepted by the filter below.
    #[test]
    fn substring_match_returns_bm25_type() {
        let (_root, project) = make_project_with_symbols();
        // The 0.99 threshold is meant to keep the fuzzy tier from producing the hit.
        let results = search_symbols_hybrid(&project, "service", 10, 0.99).unwrap();
        let text_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "substring" || r.match_type == "fts")
            .collect();
        assert!(!text_matches.is_empty());
    }

    // A misspelt query ("helpr") is recovered by the fuzzy tier as `helper`.
    #[test]
    fn fuzzy_match_finds_approximate_name() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "helpr", 10, 0.7).unwrap();
        let fuzzy: Vec<_> = results.iter().filter(|r| r.match_type == "fuzzy").collect();
        assert!(!fuzzy.is_empty(), "expected a fuzzy match for 'helpr'");
        assert_eq!(fuzzy[0].name, "helper");
    }

    // Scores never increase from one result to the next.
    #[test]
    fn results_sorted_by_score_descending() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "run_service", 20, 0.5).unwrap();
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    // A symbol found by several tiers is reported once per (name, file, line).
    #[test]
    fn no_duplicates_in_results() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "ServiceManager", 20, 0.5).unwrap();
        let mut keys = std::collections::HashSet::new();
        for r in &results {
            let key = (r.name.clone(), r.file.clone(), r.line);
            assert!(keys.insert(key), "duplicate entry found");
        }
    }

    // A semantic score above SEMANTIC_NEW_RESULT_THRESHOLD surfaces a symbol that the
    // query text itself does not match.
    #[test]
    fn semantic_scores_add_new_results() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        // Semantic scores are keyed as "{file}:{name}".
        scores.insert("main.py:helper".to_owned(), 0.8);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "authentication",
            10,
            0.99,
            Some(&scores),
            None,
        )
        .unwrap();

        let semantic_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "semantic")
            .collect();
        assert!(
            !semantic_matches.is_empty(),
            "semantic path should surface 'helper' for 'authentication' query"
        );
        assert_eq!(semantic_matches[0].name, "helper");
        assert!(semantic_matches[0].score > 0.0);
    }

    // A semantic score above SEMANTIC_BOOST_THRESHOLD raises the score of a symbol that a
    // lexical tier already found.
    #[test]
    fn semantic_scores_boost_existing_results() {
        let (_root, project) = make_project_with_symbols();
        let baseline = search_symbols_hybrid(&project, "ServiceManager", 10, 0.5).unwrap();
        let baseline_score = baseline[0].score;

        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:ServiceManager".to_owned(), 0.9);

        let boosted = search_symbols_hybrid_with_semantic(
            &project,
            "ServiceManager",
            10,
            0.5,
            Some(&scores),
            None,
        )
        .unwrap();

        assert!(
            boosted[0].score > baseline_score,
            "semantic boost should increase score: {} > {}",
            boosted[0].score,
            baseline_score
        );
    }

    // A semantic score of 0.1 is below SEMANTIC_NEW_RESULT_THRESHOLD (0.15), so it must
    // not create a "semantic" result.
    #[test]
    fn semantic_low_scores_filtered_out() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:helper".to_owned(), 0.1);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "unrelated_query_xyz",
            10,
            0.99,
            Some(&scores),
            None,
        )
        .unwrap();

        let semantic_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "semantic")
            .collect();
        assert!(
            semantic_matches.is_empty(),
            "low semantic scores should not surface results"
        );
    }

    // Deduplication still holds when the semantic tier both boosts an existing hit and
    // adds a new one.
    #[test]
    fn no_duplicates_with_semantic() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:ServiceManager".to_owned(), 0.9);
        scores.insert("main.py:helper".to_owned(), 0.7);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "ServiceManager",
            20,
            0.5,
            Some(&scores),
            None,
        )
        .unwrap();

        let mut keys = std::collections::HashSet::new();
        for r in &results {
            let key = (r.name.clone(), r.file.clone(), r.line);
            assert!(keys.insert(key.clone()), "duplicate entry found: {:?}", key);
        }
    }

    // The boost is scaled against the highest-ranked file: that file gains the full
    // `max_boost`, the others gain a proportional share.
    #[test]
    fn pagerank_boost_max_normalized_by_top_file() {
        let mut results = vec![
            SearchResult {
                name: "a".into(),
                kind: "function".into(),
                file: "popular.py".into(),
                line: 1,
                signature: "".into(),
                name_path: "a".into(),
                score: 50.0,
                match_type: "fts".into(),
            },
            SearchResult {
                name: "b".into(),
                kind: "function".into(),
                file: "obscure.py".into(),
                line: 2,
                signature: "".into(),
                name_path: "b".into(),
                score: 50.0,
                match_type: "fts".into(),
            },
        ];
        let mut pr = std::collections::HashMap::new();
        pr.insert("popular.py".into(), 0.4);
        pr.insert("obscure.py".into(), 0.05);
        apply_pagerank_boost(&mut results, &pr, 5.0);
        // popular.py: (0.4 / 0.4) * 5.0 = 5.0
        assert!((results[0].score - 55.0).abs() < 1e-6);
        // obscure.py: (0.05 / 0.4) * 5.0 = 0.625
        assert!((results[1].score - 50.625).abs() < 1e-6);
    }

    // A result whose file has no PageRank entry keeps its score.
    #[test]
    fn pagerank_boost_skips_unranked_files() {
        let mut results = vec![SearchResult {
            name: "a".into(),
            kind: "function".into(),
            file: "unmapped.py".into(),
            line: 1,
            signature: "".into(),
            name_path: "a".into(),
            score: 50.0,
            match_type: "fts".into(),
        }];
        let mut pr = std::collections::HashMap::new();
        pr.insert("other.py".into(), 0.3);
        apply_pagerank_boost(&mut results, &pr, 5.0);
        assert_eq!(results[0].score, 50.0);
    }

    // An empty PageRank map has no positive maximum, so nothing is boosted.
    #[test]
    fn pagerank_boost_zero_max_is_noop() {
        let mut results = vec![SearchResult {
            name: "a".into(),
            kind: "function".into(),
            file: "x.py".into(),
            line: 1,
            signature: "".into(),
            name_path: "a".into(),
            score: 50.0,
            match_type: "fts".into(),
        }];
        let pr = std::collections::HashMap::new();
        apply_pagerank_boost(&mut results, &pr, 5.0);
        assert_eq!(results[0].score, 50.0);
    }
}