1use super::parser::{extend_start_to_doc_comments, flatten_symbol_infos, slice_source};
2use super::ranking::{self, prune_to_budget, rank_symbols, RankingContext};
3use super::types::{
4 make_symbol_id, parse_symbol_id, RankedContextResult, SymbolInfo, SymbolKind, SymbolProvenance,
5};
6use super::SymbolIndex;
7use crate::db::IndexDb;
8use crate::project::ProjectRoot;
9use anyhow::Result;
10use std::fs;
11
12impl SymbolIndex {
13 pub(super) fn select_solve_symbols_cached(
21 &self,
22 query: &str,
23 depth: usize,
24 ) -> Result<Vec<SymbolInfo>> {
25 let query_lower = query.to_ascii_lowercase();
26 let query_tokens: Vec<&str> = query_lower
27 .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
28 .filter(|t| t.len() >= 3)
29 .collect();
30
31 let (top_files, importer_files) = {
36 let db = self.reader()?;
37 let all_paths = db.all_file_paths()?;
38
39 let mut file_scores: Vec<(String, usize)> = all_paths
40 .iter()
41 .map(|path| {
42 let path_lower = path.to_ascii_lowercase();
43 let score = query_tokens
44 .iter()
45 .filter(|t| path_lower.contains(**t))
46 .count();
47 (path.clone(), score)
48 })
49 .collect();
50
51 file_scores.sort_by(|a, b| b.1.cmp(&a.1));
52 let top: Vec<String> = file_scores
53 .iter()
54 .filter(|(_, score)| *score > 0)
55 .take(10)
56 .map(|(path, _)| path.clone())
57 .collect();
58
59 let mut importers = Vec::new();
61 if !top.is_empty() && top.len() <= 5 {
62 for file_path in top.iter().take(3) {
63 if let Ok(imp) = db.get_importers(file_path) {
64 for importer_path in imp.into_iter().take(3) {
65 importers.push(importer_path);
66 }
67 }
68 }
69 }
70
71 (top, importers)
72 };
74
75 let mut seen_ids = std::collections::HashSet::new();
76 let mut all_symbols = Vec::new();
77
78 for file_path in &top_files {
80 if let Ok(symbols) = self.get_symbols_overview_cached(file_path, depth) {
81 for sym in symbols {
82 if seen_ids.insert(sym.id.clone()) {
83 all_symbols.push(sym);
84 }
85 }
86 }
87 }
88
89 if let Ok(direct) = self.find_symbol_cached(query, None, false, false, 50) {
91 for sym in direct {
92 if seen_ids.insert(sym.id.clone()) {
93 all_symbols.push(sym);
94 }
95 }
96 }
97
98 for importer_path in &importer_files {
100 if let Ok(symbols) = self.get_symbols_overview_cached(importer_path, 1) {
101 for sym in symbols {
102 if seen_ids.insert(sym.id.clone()) {
103 all_symbols.push(sym);
104 }
105 }
106 }
107 }
108
109 if query_tokens.len() >= 2 {
111 for token in &query_tokens {
112 if let Ok(hits) = self.find_symbol_cached(token, None, false, false, 10) {
113 for sym in hits {
114 if seen_ids.insert(sym.id.clone()) {
115 all_symbols.push(sym);
116 }
117 }
118 }
119 }
120 }
121
122 if all_symbols.is_empty() {
124 return self.find_symbol_cached(query, None, false, false, 500);
125 }
126
127 Ok(all_symbols)
128 }
129
130 pub fn find_symbol_cached(
132 &self,
133 name: &str,
134 file_path: Option<&str>,
135 include_body: bool,
136 exact_match: bool,
137 max_matches: usize,
138 ) -> Result<Vec<SymbolInfo>> {
139 let db = self.reader()?;
140 if let Some((id_file, _id_kind, id_name_path)) = parse_symbol_id(name) {
142 let leaf_name = id_name_path.rsplit('/').next().unwrap_or(id_name_path);
143 let db_rows = db.find_symbols_by_name(leaf_name, Some(id_file), true, max_matches)?;
144 return Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body);
145 }
146
147 let resolved_fp = file_path.and_then(|fp| {
149 self.project
150 .resolve(fp)
151 .ok()
152 .map(|abs| self.project.to_relative(&abs))
153 });
154 let fp_ref = resolved_fp.as_deref().or(file_path);
155
156 let db_rows = db.find_symbols_by_name(name, fp_ref, exact_match, max_matches)?;
157 Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body)
158 }
159
160 pub fn get_symbols_overview_cached(
162 &self,
163 path: &str,
164 _depth: usize,
165 ) -> Result<Vec<SymbolInfo>> {
166 let db = self.reader()?;
167 let resolved = self.project.resolve(path)?;
168 if resolved.is_dir() {
169 let prefix = self.project.to_relative(&resolved);
170 let file_groups = db.get_symbols_for_directory(&prefix)?;
172 let mut symbols = Vec::new();
173 for (rel, file_symbols) in file_groups {
174 if file_symbols.is_empty() {
175 continue;
176 }
177 let id = make_symbol_id(&rel, &SymbolKind::File, &rel);
178 symbols.push(SymbolInfo {
179 name: rel.clone(),
180 kind: SymbolKind::File,
181 file_path: rel.clone(),
182 line: 0,
183 column: 0,
184 signature: format!(
185 "{} ({} symbols)",
186 std::path::Path::new(&rel)
187 .file_name()
188 .and_then(|n| n.to_str())
189 .unwrap_or(&rel),
190 file_symbols.len()
191 ),
192 name_path: rel.clone(),
193 id,
194 provenance: SymbolProvenance::from_path(&rel),
195 body: None,
196 children: file_symbols
197 .into_iter()
198 .map(|row| {
199 let kind = SymbolKind::from_str_label(&row.kind);
200 let sid = make_symbol_id(&rel, &kind, &row.name_path);
201 let row_line = row.line as usize;
202 SymbolInfo {
203 name: row.name,
204 kind,
205 file_path: rel.clone(),
206 line: row_line,
207 column: row.column_num as usize,
208 signature: row.signature,
209 name_path: row.name_path,
210 id: sid,
211 provenance: SymbolProvenance::from_path(&rel),
212 body: None,
213 children: Vec::new(),
214 start_byte: row.start_byte as u32,
215 end_byte: row.end_byte as u32,
216 end_line: if row.end_line > 0 {
221 row.end_line as usize
222 } else {
223 row_line
224 },
225 }
226 })
227 .collect(),
228 start_byte: 0,
229 end_byte: 0,
230 end_line: 0,
231 });
232 }
233 return Ok(symbols);
234 }
235
236 let relative = self.project.to_relative(&resolved);
238 let file_row = match db.get_file(&relative)? {
239 Some(row) => row,
240 None => return Ok(Vec::new()),
241 };
242 let db_symbols = db.get_file_symbols(file_row.id)?;
243 Ok(db_symbols
244 .into_iter()
245 .map(|row| {
246 let kind = SymbolKind::from_str_label(&row.kind);
247 let id = make_symbol_id(&relative, &kind, &row.name_path);
248 let row_line = row.line as usize;
249 SymbolInfo {
250 name: row.name,
251 kind,
252 file_path: relative.clone(),
253 provenance: SymbolProvenance::from_path(&relative),
254 line: row_line,
255 column: row.column_num as usize,
256 signature: row.signature,
257 name_path: row.name_path,
258 id,
259 body: None,
260 children: Vec::new(),
261 start_byte: row.start_byte as u32,
262 end_byte: row.end_byte as u32,
263 end_line: if row.end_line > 0 {
264 row.end_line as usize
265 } else {
266 row_line
267 },
268 }
269 })
270 .collect())
271 }
272
273 #[allow(clippy::too_many_arguments)]
277 pub fn get_ranked_context_cached(
278 &self,
279 query: &str,
280 path: Option<&str>,
281 max_tokens: usize,
282 include_body: bool,
283 depth: usize,
284 graph_cache: Option<&crate::import_graph::GraphCache>,
285 semantic_scores: std::collections::HashMap<String, f64>,
286 ) -> Result<RankedContextResult> {
287 self.get_ranked_context_cached_with_query_type(
288 query,
289 path,
290 max_tokens,
291 include_body,
292 depth,
293 graph_cache,
294 semantic_scores,
295 None,
296 )
297 }
298
299 #[allow(clippy::too_many_arguments)]
303 pub fn get_ranked_context_cached_with_query_type(
304 &self,
305 query: &str,
306 path: Option<&str>,
307 max_tokens: usize,
308 include_body: bool,
309 depth: usize,
310 graph_cache: Option<&crate::import_graph::GraphCache>,
311 semantic_scores: std::collections::HashMap<String, f64>,
312 query_type: Option<&str>,
313 ) -> Result<RankedContextResult> {
314 self.get_ranked_context_cached_with_lsp_boost(
315 query,
316 path,
317 max_tokens,
318 include_body,
319 depth,
320 graph_cache,
321 semantic_scores,
322 query_type,
323 std::collections::HashMap::new(),
324 None,
325 )
326 }
327
328 #[allow(clippy::too_many_arguments)]
345 pub fn get_ranked_context_cached_with_lsp_boost(
346 &self,
347 query: &str,
348 path: Option<&str>,
349 max_tokens: usize,
350 include_body: bool,
351 depth: usize,
352 graph_cache: Option<&crate::import_graph::GraphCache>,
353 semantic_scores: std::collections::HashMap<String, f64>,
354 query_type: Option<&str>,
355 mut lsp_boost_refs: std::collections::HashMap<String, Vec<usize>>,
356 lsp_signal_weight: Option<f64>,
357 ) -> Result<RankedContextResult> {
358 for lines in lsp_boost_refs.values_mut() {
362 lines.sort_unstable();
363 lines.dedup();
364 }
365
366 let mut all_symbols = if let Some(path) = path {
367 self.get_symbols_overview_cached(path, depth)?
368 } else {
369 self.select_solve_symbols_cached(query, depth)?
370 };
371
372 if !lsp_boost_refs.is_empty() {
380 let mut seen: std::collections::HashSet<String> =
381 all_symbols.iter().map(|s| s.id.clone()).collect();
382 for extra_path in lsp_boost_refs.keys() {
383 if Some(extra_path.as_str()) == path {
384 continue;
385 }
386 if let Ok(extra_symbols) = self.get_symbols_overview_cached(extra_path, depth) {
387 for sym in extra_symbols {
388 if seen.insert(sym.id.clone()) {
389 all_symbols.push(sym);
390 }
391 }
392 }
393 }
394 }
395
396 let ranking_ctx = match graph_cache {
397 Some(gc) => {
398 let pagerank = gc.file_pagerank_scores(&self.project);
399 if semantic_scores.is_empty() {
400 RankingContext::with_pagerank(pagerank)
401 } else {
402 RankingContext::with_pagerank_and_semantic(query, pagerank, semantic_scores)
403 }
404 }
405 None => {
406 if semantic_scores.is_empty() {
407 RankingContext::text_only()
408 } else {
409 RankingContext::with_pagerank_and_semantic(
410 query,
411 std::collections::HashMap::new(),
412 semantic_scores,
413 )
414 }
415 }
416 };
417
418 let ranking_ctx = if let Some(qt) = query_type {
420 let mut ctx = ranking_ctx;
421 ctx.weights = ranking::weights_for_query_type(qt);
422 ctx
423 } else {
424 ranking_ctx
425 };
426
427 let ranking_ctx = if lsp_boost_refs.is_empty() && lsp_signal_weight.is_none() {
432 ranking_ctx
433 } else {
434 let mut ctx = ranking_ctx;
435 ctx.lsp_boost_refs = lsp_boost_refs;
436 if let Some(w) = lsp_signal_weight {
437 ctx.weights.lsp_signal = w;
438 }
439 ctx
440 };
441
442 let flat_symbols: Vec<SymbolInfo> = all_symbols
443 .into_iter()
444 .flat_map(flatten_symbol_infos)
445 .collect();
446
447 let scored = rank_symbols(query, flat_symbols, &ranking_ctx);
448
449 let (selected, chars_used, pruned_count, last_kept_score) =
450 prune_to_budget(scored, max_tokens, include_body, self.project.as_path());
451
452 Ok(RankedContextResult {
453 query: query.to_owned(),
454 count: selected.len(),
455 symbols: selected,
456 token_budget: max_tokens,
457 chars_used,
458 pruned_count,
459 last_kept_score,
460 })
461 }
462
463 pub(super) fn rows_to_symbol_infos(
466 project: &ProjectRoot,
467 db: &IndexDb,
468 rows: Vec<crate::db::SymbolRow>,
469 include_body: bool,
470 ) -> Result<Vec<SymbolInfo>> {
471 let mut results = Vec::new();
472 let mut path_cache: std::collections::HashMap<i64, String> =
473 std::collections::HashMap::new();
474 for row in rows {
475 let rel_path = match path_cache.get(&row.file_id) {
476 Some(p) => p.clone(),
477 None => {
478 let p = db.get_file_path(row.file_id)?.unwrap_or_default();
479 path_cache.insert(row.file_id, p.clone());
480 p
481 }
482 };
483 let body = if include_body {
484 let abs = project.as_path().join(&rel_path);
485 fs::read_to_string(&abs).ok().map(|source| {
486 let extended_start =
490 extend_start_to_doc_comments(&source, row.start_byte as u32);
491 slice_source(&source, extended_start, row.end_byte as u32)
492 })
493 } else {
494 None
495 };
496 let kind = SymbolKind::from_str_label(&row.kind);
497 let id = make_symbol_id(&rel_path, &kind, &row.name_path);
498 let row_line = row.line as usize;
499 results.push(SymbolInfo {
500 name: row.name,
501 kind,
502 provenance: SymbolProvenance::from_path(&rel_path),
503 file_path: rel_path,
504 line: row_line,
505 column: row.column_num as usize,
506 signature: row.signature,
507 name_path: row.name_path,
508 id,
509 body,
510 children: Vec::new(),
511 start_byte: row.start_byte as u32,
512 end_byte: row.end_byte as u32,
513 end_line: if row.end_line > 0 {
514 row.end_line as usize
515 } else {
516 row_line
517 },
518 });
519 }
520 Ok(results)
521 }
522}