1use crate::db::{IndexDb, index_db_path};
2use crate::project::ProjectRoot;
3use anyhow::Result;
4use serde::Serialize;
5use strsim::jaro_winkler;
6
/// Semantic score an already-matched symbol must strictly exceed before
/// `search_symbols_hybrid_with_semantic` adds a semantic boost to its score.
pub const SEMANTIC_BOOST_THRESHOLD: f64 = 0.10;
/// Semantic score a symbol that no text stage matched must strictly exceed
/// before `search_symbols_hybrid_with_semantic` adds it as a `"semantic"` result.
pub const SEMANTIC_NEW_RESULT_THRESHOLD: f64 = 0.15;
/// Not referenced anywhere in this file.
/// NOTE(review): presumably the cut-off used by a semantic-coupling analysis
/// elsewhere in the crate — confirm against its callers.
pub const SEMANTIC_COUPLING_THRESHOLD: f64 = 0.12;
13
/// One ranked hit produced by the hybrid symbol search in this module.
///
/// Serialized with `serde`, so field names and order are part of the output shape.
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
    /// Symbol name as returned by the index.
    pub name: String,
    /// Symbol kind string as returned by the index (the tests use values such
    /// as `"class"` and `"function"`).
    pub kind: String,
    /// Path of the file containing the symbol, as returned by the index.
    pub file: String,
    /// Line of the symbol, converted from the index's `i64` value with `as usize`.
    pub line: usize,
    /// Signature text as returned by the index.
    pub signature: String,
    /// Name path of the symbol as returned by the index.
    pub name_path: String,
    /// Ranking score: `100.0` for exact matches, `40.0..=80.0` for FTS matches,
    /// `similarity * 100.0` for fuzzy matches, `semantic_score * 90.0` for
    /// semantic-only matches; text matches may additionally receive a semantic
    /// boost of at most `10.0`.
    pub score: f64,
    /// Which stage produced the hit: `"exact"`, `"fts"`, `"fuzzy"` or `"semantic"`.
    pub match_type: String,
}
25
26pub fn search_symbols_hybrid(
35 project: &ProjectRoot,
36 query: &str,
37 max_results: usize,
38 fuzzy_threshold: f64,
39) -> Result<Vec<SearchResult>> {
40 search_symbols_hybrid_with_semantic(project, query, max_results, fuzzy_threshold, None)
41}
42
43pub fn search_symbols_hybrid_with_semantic(
45 project: &ProjectRoot,
46 query: &str,
47 max_results: usize,
48 fuzzy_threshold: f64,
49 semantic_scores: Option<&std::collections::HashMap<String, f64>>,
50) -> Result<Vec<SearchResult>> {
51 let db_path = index_db_path(project.as_path());
52 let db = IndexDb::open(&db_path)?;
53
54 let mut seen: std::collections::HashSet<(String, String, i64)> =
55 std::collections::HashSet::new();
56 let mut results: Vec<SearchResult> = Vec::new();
57
58 for (row, file) in db.find_symbols_with_path(query, true, max_results)? {
60 let key = (row.name.clone(), file.clone(), row.line);
61 if seen.insert(key) {
62 results.push(SearchResult {
63 name: row.name,
64 kind: row.kind,
65 file,
66 line: row.line as usize,
67 signature: row.signature,
68 name_path: row.name_path,
69 score: 100.0,
70 match_type: "exact".to_owned(),
71 });
72 }
73 }
74
75 for (row, file, rank) in db.search_symbols_fts(query, max_results)? {
78 let key = (row.name.clone(), file.clone(), row.line);
79 if seen.insert(key) {
80 let fts_score = (80.0 + rank.clamp(-40.0, 0.0)).max(40.0);
82 results.push(SearchResult {
83 name: row.name,
84 kind: row.kind,
85 file,
86 line: row.line as usize,
87 signature: row.signature,
88 name_path: row.name_path,
89 score: fts_score,
90 match_type: "fts".to_owned(),
91 });
92 }
93 }
94
95 let query_lower = query.to_ascii_lowercase();
97 let prefix: String = query_lower.chars().take(2).collect();
98 let fuzzy_candidates = if prefix.len() >= 2 {
99 db.find_symbols_with_path(&prefix, false, 500)?
100 } else {
101 db.find_symbols_with_path(&query_lower, false, 500)?
102 };
103 for (row, file) in fuzzy_candidates {
104 let key = (row.name.clone(), file.clone(), row.line);
105 if seen.contains(&key) {
106 continue;
107 }
108 let sim = jaro_winkler(&query_lower, &row.name.to_ascii_lowercase());
109 if sim >= fuzzy_threshold {
110 seen.insert(key);
111 results.push(SearchResult {
112 name: row.name,
113 kind: row.kind,
114 file,
115 line: row.line as usize,
116 signature: row.signature,
117 name_path: row.name_path,
118 score: sim * 100.0,
119 match_type: "fuzzy".to_owned(),
120 });
121 }
122 }
123
124 if let Some(scores) = semantic_scores {
126 let all_symbols = db.all_symbol_names()?;
131 for (name, kind, file_path, line, signature, name_path) in all_symbols {
132 let key = (name.clone(), file_path.clone(), line);
133 if seen.contains(&key) {
134 let sem_key = format!("{file_path}:{name}");
135 if let Some(&sem_score) = scores.get(&sem_key)
136 && sem_score > SEMANTIC_BOOST_THRESHOLD
137 && let Some(existing) = results
138 .iter_mut()
139 .find(|r| r.name == name && r.file == file_path && r.line == line as usize)
140 {
141 existing.score += (sem_score * 15.0).min(10.0);
142 }
143 continue;
144 }
145 let sem_key = format!("{file_path}:{name}");
146 if let Some(&sem_score) = scores
147 .get(&sem_key)
148 .filter(|&&s| s > SEMANTIC_NEW_RESULT_THRESHOLD)
149 {
150 seen.insert(key);
151 results.push(SearchResult {
152 name,
153 kind,
154 file: file_path,
155 line: line as usize,
156 signature,
157 name_path,
158 score: sem_score * 90.0,
159 match_type: "semantic".to_owned(),
160 });
161 }
162 }
163 }
164
165 results.sort_by(|a, b| {
166 b.score
167 .partial_cmp(&a.score)
168 .unwrap_or(std::cmp::Ordering::Equal)
169 });
170
171 const MAX_PER_FILE: usize = 3;
174 if results.len() > max_results {
175 let mut file_counts: std::collections::HashMap<String, usize> =
176 std::collections::HashMap::new();
177 let mut promoted = Vec::with_capacity(max_results);
178 let mut demoted = Vec::new();
179 for r in results {
180 let count = file_counts.entry(r.file.clone()).or_insert(0);
181 if *count < MAX_PER_FILE {
182 *count += 1;
183 promoted.push(r);
184 } else {
185 demoted.push(r);
186 }
187 }
188 promoted.extend(demoted);
189 results = promoted;
190 }
191
192 results.truncate(max_results);
193 Ok(results)
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{IndexDb, NewSymbol, index_db_path};
    use std::sync::atomic::AtomicUsize;

    /// Per-process sequence number that keeps fixture directories distinct even
    /// when two tests read the clock at the same instant.
    static FIXTURE_SEQ: AtomicUsize = AtomicUsize::new(0);

    /// Creates a fresh temporary project whose index contains three symbols in
    /// `main.py`: `ServiceManager` (class), `run_service` and `helper`
    /// (functions). Returns the project directory and the `ProjectRoot` for it.
    ///
    /// The directory name combines the process id, the full nanosecond
    /// timestamp and a per-process counter, so parallel tests and earlier runs
    /// never share (and re-populate) the same index database.
    fn make_project_with_symbols() -> (std::path::PathBuf, ProjectRoot) {
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let seq = FIXTURE_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let pid = std::process::id();
        let root =
            std::env::temp_dir().join(format!("codelens_search_test_{pid}_{nanos}_{seq}"));
        std::fs::create_dir_all(&root).unwrap();

        std::fs::write(root.join("hello.txt"), "hello").unwrap();

        let db_path = index_db_path(&root);
        let db = IndexDb::open(&db_path).unwrap();
        let fid = db
            .upsert_file("main.py", 100, "h1", 10, Some("py"))
            .unwrap();
        db.insert_symbols(
            fid,
            &[
                NewSymbol {
                    name: "ServiceManager",
                    kind: "class",
                    line: 1,
                    column_num: 0,
                    start_byte: 0,
                    end_byte: 100,
                    signature: "class ServiceManager:",
                    name_path: "ServiceManager",
                    parent_id: None,
                    end_line: 0,
                },
                NewSymbol {
                    name: "run_service",
                    kind: "function",
                    line: 10,
                    column_num: 0,
                    start_byte: 101,
                    end_byte: 200,
                    signature: "def run_service():",
                    name_path: "run_service",
                    parent_id: None,
                    end_line: 0,
                },
                NewSymbol {
                    name: "helper",
                    kind: "function",
                    line: 20,
                    column_num: 0,
                    start_byte: 201,
                    end_byte: 300,
                    signature: "def helper():",
                    name_path: "helper",
                    parent_id: None,
                    end_line: 0,
                },
            ],
        )
        .unwrap();

        let project = ProjectRoot::new(root.to_str().unwrap()).unwrap();
        (root, project)
    }

    // An exact name hit must rank first with the fixed exact-match score.
    #[test]
    fn exact_match_gets_highest_score() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "ServiceManager", 10, 0.6).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].name, "ServiceManager");
        assert_eq!(results[0].match_type, "exact");
        assert_eq!(results[0].score, 100.0);
    }

    // With fuzzy matching effectively disabled (threshold 0.99), a partial
    // word must still be found by a text-based stage.
    #[test]
    fn substring_match_returns_bm25_type() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "service", 10, 0.99).unwrap();
        let text_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "substring" || r.match_type == "fts")
            .collect();
        assert!(!text_matches.is_empty());
    }

    // A misspelled query should surface the intended symbol via the fuzzy stage.
    #[test]
    fn fuzzy_match_finds_approximate_name() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "helpr", 10, 0.7).unwrap();
        let fuzzy: Vec<_> = results.iter().filter(|r| r.match_type == "fuzzy").collect();
        assert!(!fuzzy.is_empty(), "expected a fuzzy match for 'helpr'");
        assert_eq!(fuzzy[0].name, "helper");
    }

    // Scores must be non-increasing from the first result to the last.
    #[test]
    fn results_sorted_by_score_descending() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "run_service", 20, 0.5).unwrap();
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    // A symbol found by several stages must appear only once.
    #[test]
    fn no_duplicates_in_results() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "ServiceManager", 20, 0.5).unwrap();
        let mut keys = std::collections::HashSet::new();
        for r in &results {
            let key = (r.name.clone(), r.file.clone(), r.line);
            assert!(keys.insert(key), "duplicate entry found");
        }
    }

    // A high semantic score must add a symbol that no text stage matched.
    #[test]
    fn semantic_scores_add_new_results() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:helper".to_owned(), 0.8);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "authentication",
            10,
            0.99,
            Some(&scores),
        )
        .unwrap();

        let semantic_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "semantic")
            .collect();
        assert!(
            !semantic_matches.is_empty(),
            "semantic path should surface 'helper' for 'authentication' query"
        );
        assert_eq!(semantic_matches[0].name, "helper");
        assert!(semantic_matches[0].score > 0.0);
    }

    // A semantic score on an already-matched symbol must raise its score.
    #[test]
    fn semantic_scores_boost_existing_results() {
        let (_root, project) = make_project_with_symbols();
        let baseline = search_symbols_hybrid(&project, "ServiceManager", 10, 0.5).unwrap();
        let baseline_score = baseline[0].score;

        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:ServiceManager".to_owned(), 0.9);

        let boosted =
            search_symbols_hybrid_with_semantic(&project, "ServiceManager", 10, 0.5, Some(&scores))
                .unwrap();

        assert!(
            boosted[0].score > baseline_score,
            "semantic boost should increase score: {} > {}",
            boosted[0].score,
            baseline_score
        );
    }

    // A semantic score below the new-result threshold must not add anything.
    #[test]
    fn semantic_low_scores_filtered_out() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:helper".to_owned(), 0.1);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "unrelated_query_xyz",
            10,
            0.99,
            Some(&scores),
        )
        .unwrap();

        let semantic_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "semantic")
            .collect();
        assert!(
            semantic_matches.is_empty(),
            "low semantic scores should not surface results"
        );
    }

    // Boosting and semantic additions together must still yield unique entries.
    #[test]
    fn no_duplicates_with_semantic() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:ServiceManager".to_owned(), 0.9);
        scores.insert("main.py:helper".to_owned(), 0.7);

        let results =
            search_symbols_hybrid_with_semantic(&project, "ServiceManager", 20, 0.5, Some(&scores))
                .unwrap();

        let mut keys = std::collections::HashSet::new();
        for r in &results {
            let key = (r.name.clone(), r.file.clone(), r.line);
            assert!(keys.insert(key.clone()), "duplicate entry found: {:?}", key);
        }
    }
}