1use crate::db::{IndexDb, index_db_path};
2use crate::project::ProjectRoot;
3use anyhow::Result;
4use serde::Serialize;
5use strsim::jaro_winkler;
6
/// Semantic similarity a symbol that was already found by exact, full-text or fuzzy matching
/// must strictly exceed to receive a score bonus in [`search_symbols_hybrid_with_semantic`].
pub const SEMANTIC_BOOST_THRESHOLD: f64 = 0.10;
/// Semantic similarity a symbol that no other strategy found must strictly exceed to be added
/// as a `"semantic"` result in [`search_symbols_hybrid_with_semantic`].
pub const SEMANTIC_NEW_RESULT_THRESHOLD: f64 = 0.15;
/// Similarity cutoff that is not referenced by the search code in this module.
// NOTE(review): presumably consumed by a coupling/relatedness feature elsewhere — confirm at
// its use sites before documenting its exact meaning.
pub const SEMANTIC_COUPLING_THRESHOLD: f64 = 0.12;
13
14#[derive(Debug, Clone, Serialize)]
15pub struct SearchResult {
16 pub name: String,
17 pub kind: String,
18 pub file: String,
19 pub line: usize,
20 pub signature: String,
21 pub name_path: String,
22 pub score: f64,
23 pub match_type: String, }
25
26pub fn search_symbols_hybrid(
35 project: &ProjectRoot,
36 query: &str,
37 max_results: usize,
38 fuzzy_threshold: f64,
39) -> Result<Vec<SearchResult>> {
40 search_symbols_hybrid_with_semantic(project, query, max_results, fuzzy_threshold, None)
41}
42
43pub fn search_symbols_hybrid_with_semantic(
45 project: &ProjectRoot,
46 query: &str,
47 max_results: usize,
48 fuzzy_threshold: f64,
49 semantic_scores: Option<&std::collections::HashMap<String, f64>>,
50) -> Result<Vec<SearchResult>> {
51 let db_path = index_db_path(project.as_path());
52 let db = IndexDb::open(&db_path)?;
53
54 let mut seen: std::collections::HashSet<(String, String, i64)> =
55 std::collections::HashSet::new();
56 let mut results: Vec<SearchResult> = Vec::new();
57
58 for (row, file) in db.find_symbols_with_path(query, true, max_results)? {
60 let key = (row.name.clone(), file.clone(), row.line);
61 if seen.insert(key) {
62 results.push(SearchResult {
63 name: row.name,
64 kind: row.kind,
65 file,
66 line: row.line as usize,
67 signature: row.signature,
68 name_path: row.name_path,
69 score: 100.0,
70 match_type: "exact".to_owned(),
71 });
72 }
73 }
74
75 for (row, file, rank) in db.search_symbols_fts(query, max_results)? {
78 let key = (row.name.clone(), file.clone(), row.line);
79 if seen.insert(key) {
80 let fts_score = (80.0 + rank.clamp(-40.0, 0.0)).max(40.0);
82 results.push(SearchResult {
83 name: row.name,
84 kind: row.kind,
85 file,
86 line: row.line as usize,
87 signature: row.signature,
88 name_path: row.name_path,
89 score: fts_score,
90 match_type: "fts".to_owned(),
91 });
92 }
93 }
94
95 let query_lower = query.to_ascii_lowercase();
97 let prefix: String = query_lower.chars().take(2).collect();
98 let fuzzy_candidates = if prefix.len() >= 2 {
99 db.find_symbols_with_path(&prefix, false, 500)?
100 } else {
101 db.find_symbols_with_path(&query_lower, false, 500)?
102 };
103 for (row, file) in fuzzy_candidates {
104 let key = (row.name.clone(), file.clone(), row.line);
105 if seen.contains(&key) {
106 continue;
107 }
108 let sim = jaro_winkler(&query_lower, &row.name.to_ascii_lowercase());
109 if sim >= fuzzy_threshold {
110 seen.insert(key);
111 results.push(SearchResult {
112 name: row.name,
113 kind: row.kind,
114 file,
115 line: row.line as usize,
116 signature: row.signature,
117 name_path: row.name_path,
118 score: sim * 100.0,
119 match_type: "fuzzy".to_owned(),
120 });
121 }
122 }
123
124 if let Some(scores) = semantic_scores {
126 let all_symbols = db.all_symbol_names()?;
131 for (name, kind, file_path, line, signature, name_path) in all_symbols {
132 let key = (name.clone(), file_path.clone(), line);
133 if seen.contains(&key) {
134 let sem_key = format!("{file_path}:{name}");
135 if let Some(&sem_score) = scores.get(&sem_key)
136 && sem_score > SEMANTIC_BOOST_THRESHOLD
137 && let Some(existing) = results
138 .iter_mut()
139 .find(|r| r.name == name && r.file == file_path && r.line == line as usize)
140 {
141 existing.score += (sem_score * 15.0).min(10.0);
142 }
143 continue;
144 }
145 let sem_key = format!("{file_path}:{name}");
146 if let Some(&sem_score) = scores
147 .get(&sem_key)
148 .filter(|&&s| s > SEMANTIC_NEW_RESULT_THRESHOLD)
149 {
150 seen.insert(key);
151 results.push(SearchResult {
152 name,
153 kind,
154 file: file_path,
155 line: line as usize,
156 signature,
157 name_path,
158 score: sem_score * 90.0,
159 match_type: "semantic".to_owned(),
160 });
161 }
162 }
163 }
164
165 results.sort_by(|a, b| {
166 b.score
167 .partial_cmp(&a.score)
168 .unwrap_or(std::cmp::Ordering::Equal)
169 });
170
171 const MAX_PER_FILE: usize = 3;
174 if results.len() > max_results {
175 let mut file_counts: std::collections::HashMap<String, usize> =
176 std::collections::HashMap::new();
177 let mut promoted = Vec::with_capacity(max_results);
178 let mut demoted = Vec::new();
179 for r in results {
180 let count = file_counts.entry(r.file.clone()).or_insert(0);
181 if *count < MAX_PER_FILE {
182 *count += 1;
183 promoted.push(r);
184 } else {
185 demoted.push(r);
186 }
187 }
188 promoted.extend(demoted);
189 results = promoted;
190 }
191
192 results.truncate(max_results);
193 Ok(results)
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{IndexDb, NewSymbol, index_db_path};

    /// Scratch project directory, removed (best effort) when the guard is dropped.
    struct TempProject(std::path::PathBuf);

    impl Drop for TempProject {
        fn drop(&mut self) {
            // Cleanup failures (e.g. a file still open on some platforms) must not fail a test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh project whose index holds three symbols in `main.py`: class
    /// `ServiceManager` (line 1), function `run_service` (line 10) and function `helper`
    /// (line 20).
    ///
    /// The directory name combines the process id, a per-process counter and the current
    /// sub-second nanos, so tests running on parallel threads do not share a directory (and
    /// therefore an index). The returned guard deletes the directory when dropped.
    fn make_project_with_symbols() -> (TempProject, ProjectRoot) {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .subsec_nanos();
        let root = std::env::temp_dir().join(format!(
            "codelens_search_test_{}_{}_{nanos}",
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed),
        ));
        // A leftover directory from an earlier run would carry an already-populated index.
        let _ = std::fs::remove_dir_all(&root);
        std::fs::create_dir_all(&root).unwrap();
        // Take ownership right away so the directory is cleaned up even if setup panics.
        let guard = TempProject(root);

        std::fs::write(guard.0.join("hello.txt"), "hello").unwrap();

        let db_path = index_db_path(&guard.0);
        let db = IndexDb::open(&db_path).unwrap();
        let fid = db
            .upsert_file("main.py", 100, "h1", 10, Some("py"))
            .unwrap();
        db.insert_symbols(
            fid,
            &[
                NewSymbol {
                    name: "ServiceManager",
                    kind: "class",
                    line: 1,
                    column_num: 0,
                    start_byte: 0,
                    end_byte: 100,
                    signature: "class ServiceManager:",
                    name_path: "ServiceManager",
                    parent_id: None,
                },
                NewSymbol {
                    name: "run_service",
                    kind: "function",
                    line: 10,
                    column_num: 0,
                    start_byte: 101,
                    end_byte: 200,
                    signature: "def run_service():",
                    name_path: "run_service",
                    parent_id: None,
                },
                NewSymbol {
                    name: "helper",
                    kind: "function",
                    line: 20,
                    column_num: 0,
                    start_byte: 201,
                    end_byte: 300,
                    signature: "def helper():",
                    name_path: "helper",
                    parent_id: None,
                },
            ],
        )
        .unwrap();

        let project = ProjectRoot::new(guard.0.to_str().unwrap()).unwrap();
        (guard, project)
    }

    #[test]
    fn exact_match_gets_highest_score() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "ServiceManager", 10, 0.6).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].name, "ServiceManager");
        assert_eq!(results[0].match_type, "exact");
        assert_eq!(results[0].score, 100.0);
    }

    #[test]
    fn substring_match_returns_bm25_type() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "service", 10, 0.99).unwrap();
        let text_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "substring" || r.match_type == "fts")
            .collect();
        assert!(!text_matches.is_empty());
    }

    #[test]
    fn fuzzy_match_finds_approximate_name() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "helpr", 10, 0.7).unwrap();
        let fuzzy: Vec<_> = results.iter().filter(|r| r.match_type == "fuzzy").collect();
        assert!(!fuzzy.is_empty(), "expected a fuzzy match for 'helpr'");
        assert_eq!(fuzzy[0].name, "helper");
    }

    #[test]
    fn results_sorted_by_score_descending() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "run_service", 20, 0.5).unwrap();
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn no_duplicates_in_results() {
        let (_root, project) = make_project_with_symbols();
        let results = search_symbols_hybrid(&project, "ServiceManager", 20, 0.5).unwrap();
        let mut keys = std::collections::HashSet::new();
        for r in &results {
            let key = (r.name.clone(), r.file.clone(), r.line);
            assert!(keys.insert(key), "duplicate entry found");
        }
    }

    #[test]
    fn semantic_scores_add_new_results() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:helper".to_owned(), 0.8);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "authentication",
            10,
            0.99,
            Some(&scores),
        )
        .unwrap();

        let semantic_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "semantic")
            .collect();
        assert!(
            !semantic_matches.is_empty(),
            "semantic path should surface 'helper' for 'authentication' query"
        );
        assert_eq!(semantic_matches[0].name, "helper");
        assert!(semantic_matches[0].score > 0.0);
    }

    #[test]
    fn semantic_scores_boost_existing_results() {
        let (_root, project) = make_project_with_symbols();
        let baseline = search_symbols_hybrid(&project, "ServiceManager", 10, 0.5).unwrap();
        let baseline_score = baseline[0].score;

        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:ServiceManager".to_owned(), 0.9);

        let boosted =
            search_symbols_hybrid_with_semantic(&project, "ServiceManager", 10, 0.5, Some(&scores))
                .unwrap();

        assert!(
            boosted[0].score > baseline_score,
            "semantic boost should increase score: {} > {}",
            boosted[0].score,
            baseline_score
        );
    }

    #[test]
    fn semantic_low_scores_filtered_out() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:helper".to_owned(), 0.1);

        let results = search_symbols_hybrid_with_semantic(
            &project,
            "unrelated_query_xyz",
            10,
            0.99,
            Some(&scores),
        )
        .unwrap();

        let semantic_matches: Vec<_> = results
            .iter()
            .filter(|r| r.match_type == "semantic")
            .collect();
        assert!(
            semantic_matches.is_empty(),
            "low semantic scores should not surface results"
        );
    }

    #[test]
    fn no_duplicates_with_semantic() {
        let (_root, project) = make_project_with_symbols();
        let mut scores = std::collections::HashMap::new();
        scores.insert("main.py:ServiceManager".to_owned(), 0.9);
        scores.insert("main.py:helper".to_owned(), 0.7);

        let results =
            search_symbols_hybrid_with_semantic(&project, "ServiceManager", 20, 0.5, Some(&scores))
                .unwrap();

        let mut keys = std::collections::HashSet::new();
        for r in &results {
            let key = (r.name.clone(), r.file.clone(), r.line);
            assert!(keys.insert(key.clone()), "duplicate entry found: {:?}", key);
        }
    }
}