1use super::SymbolIndex;
2use super::parser::{flatten_symbol_infos, slice_source};
3use super::ranking::{self, RankingContext, prune_to_budget, rank_symbols};
4use super::types::{
5 RankedContextResult, SymbolInfo, SymbolKind, SymbolProvenance, make_symbol_id, parse_symbol_id,
6};
7use crate::db::IndexDb;
8use crate::project::ProjectRoot;
9use anyhow::Result;
10use std::fs;
11
12impl SymbolIndex {
13 pub(super) fn select_solve_symbols_cached(
21 &self,
22 query: &str,
23 depth: usize,
24 ) -> Result<Vec<SymbolInfo>> {
25 let query_lower = query.to_ascii_lowercase();
26 let query_tokens: Vec<&str> = query_lower
27 .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
28 .filter(|t| t.len() >= 3)
29 .collect();
30
31 let (top_files, importer_files) = {
36 let db = self.reader()?;
37 let all_paths = db.all_file_paths()?;
38
39 let mut file_scores: Vec<(String, usize)> = all_paths
40 .iter()
41 .map(|path| {
42 let path_lower = path.to_ascii_lowercase();
43 let score = query_tokens
44 .iter()
45 .filter(|t| path_lower.contains(**t))
46 .count();
47 (path.clone(), score)
48 })
49 .collect();
50
51 file_scores.sort_by_key(|b| std::cmp::Reverse(b.1));
52 let top: Vec<String> = file_scores
53 .iter()
54 .filter(|(_, score)| *score > 0)
55 .take(10)
56 .map(|(path, _)| path.clone())
57 .collect();
58
59 let mut importers = Vec::new();
61 if !top.is_empty() && top.len() <= 5 {
62 for file_path in top.iter().take(3) {
63 if let Ok(imp) = db.get_importers(file_path) {
64 for importer_path in imp.into_iter().take(3) {
65 importers.push(importer_path);
66 }
67 }
68 }
69 }
70
71 (top, importers)
72 };
74
75 let mut seen_ids = std::collections::HashSet::new();
76 let mut all_symbols = Vec::new();
77
78 for file_path in &top_files {
80 if let Ok(symbols) = self.get_symbols_overview_cached(file_path, depth) {
81 for sym in symbols {
82 if seen_ids.insert(sym.id.clone()) {
83 all_symbols.push(sym);
84 }
85 }
86 }
87 }
88
89 if let Ok(direct) = self.find_symbol_cached(query, None, false, false, 50) {
91 for sym in direct {
92 if seen_ids.insert(sym.id.clone()) {
93 all_symbols.push(sym);
94 }
95 }
96 }
97
98 for importer_path in &importer_files {
100 if let Ok(symbols) = self.get_symbols_overview_cached(importer_path, 1) {
101 for sym in symbols {
102 if seen_ids.insert(sym.id.clone()) {
103 all_symbols.push(sym);
104 }
105 }
106 }
107 }
108
109 if query_tokens.len() >= 2 {
111 for token in &query_tokens {
112 if let Ok(hits) = self.find_symbol_cached(token, None, false, false, 10) {
113 for sym in hits {
114 if seen_ids.insert(sym.id.clone()) {
115 all_symbols.push(sym);
116 }
117 }
118 }
119 }
120 }
121
122 if all_symbols.is_empty() {
124 return self.find_symbol_cached(query, None, false, false, 500);
125 }
126
127 Ok(all_symbols)
128 }
129
130 pub fn find_symbol_cached(
132 &self,
133 name: &str,
134 file_path: Option<&str>,
135 include_body: bool,
136 exact_match: bool,
137 max_matches: usize,
138 ) -> Result<Vec<SymbolInfo>> {
139 let db = self.reader()?;
140 if let Some((id_file, _id_kind, id_name_path)) = parse_symbol_id(name) {
142 let resolved = self.project.resolve(id_file)?;
143 let relative = self.project.to_relative(&resolved);
144 let db_rows = db.find_symbols_by_name_path(&relative, id_name_path, max_matches)?;
145 return Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body);
146 }
147
148 let resolved_fp = file_path.and_then(|fp| {
150 self.project
151 .resolve(fp)
152 .ok()
153 .map(|abs| self.project.to_relative(&abs))
154 });
155 let fp_ref = resolved_fp.as_deref().or(file_path);
156
157 let db_rows = db.find_symbols_by_name(name, fp_ref, exact_match, max_matches)?;
158 Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body)
159 }
160
161 pub fn file_is_indexed(&self, path: &str) -> Result<bool> {
167 let db = self.reader()?;
168 let resolved = self.project.resolve(path)?;
169 if resolved.is_dir() {
170 return Ok(false);
171 }
172 let relative = self.project.to_relative(&resolved);
173 Ok(db.get_file(&relative)?.is_some())
174 }
175
176 pub fn get_symbols_overview_cached(
178 &self,
179 path: &str,
180 _depth: usize,
181 ) -> Result<Vec<SymbolInfo>> {
182 let db = self.reader()?;
183 let resolved = self.project.resolve(path)?;
184 if resolved.is_dir() {
185 let prefix = self.project.to_relative(&resolved);
186 let file_groups = db.get_symbols_for_directory(&prefix)?;
188 let mut symbols = Vec::new();
189 for (rel, file_symbols) in file_groups {
190 if file_symbols.is_empty() {
191 continue;
192 }
193 let id = make_symbol_id(&rel, &SymbolKind::File, &rel);
194 symbols.push(SymbolInfo {
195 name: rel.clone(),
196 kind: SymbolKind::File,
197 file_path: rel.clone(),
198 line: 0,
199 column: 0,
200 signature: format!(
201 "{} ({} symbols)",
202 std::path::Path::new(&rel)
203 .file_name()
204 .and_then(|n| n.to_str())
205 .unwrap_or(&rel),
206 file_symbols.len()
207 ),
208 name_path: rel.clone(),
209 id,
210 provenance: SymbolProvenance::from_path(&rel),
211 body: None,
212 children: file_symbols
213 .into_iter()
214 .map(|row| {
215 let kind = SymbolKind::from_str_label(&row.kind);
216 let sid = make_symbol_id(&rel, &kind, &row.name_path);
217 SymbolInfo {
218 name: row.name,
219 kind,
220 file_path: rel.clone(),
221 line: row.line as usize,
222 column: row.column_num as usize,
223 signature: row.signature,
224 name_path: row.name_path,
225 id: sid,
226 provenance: SymbolProvenance::from_path(&rel),
227 body: None,
228 children: Vec::new(),
229 start_byte: row.start_byte as u32,
230 end_byte: row.end_byte as u32,
231 }
232 })
233 .collect(),
234 start_byte: 0,
235 end_byte: 0,
236 });
237 }
238 return Ok(symbols);
239 }
240
241 let relative = self.project.to_relative(&resolved);
243 let file_row = match db.get_file(&relative)? {
244 Some(row) => row,
245 None => return Ok(Vec::new()),
246 };
247 let db_symbols = db.get_file_symbols(file_row.id)?;
248 Ok(db_symbols
249 .into_iter()
250 .map(|row| {
251 let kind = SymbolKind::from_str_label(&row.kind);
252 let id = make_symbol_id(&relative, &kind, &row.name_path);
253 SymbolInfo {
254 name: row.name,
255 kind,
256 file_path: relative.clone(),
257 provenance: SymbolProvenance::from_path(&relative),
258 line: row.line as usize,
259 column: row.column_num as usize,
260 signature: row.signature,
261 name_path: row.name_path,
262 id,
263 body: None,
264 children: Vec::new(),
265 start_byte: row.start_byte as u32,
266 end_byte: row.end_byte as u32,
267 }
268 })
269 .collect())
270 }
271
272 #[allow(clippy::too_many_arguments)]
276 pub fn get_ranked_context_cached(
277 &self,
278 query: &str,
279 path: Option<&str>,
280 max_tokens: usize,
281 include_body: bool,
282 depth: usize,
283 graph_cache: Option<&crate::import_graph::GraphCache>,
284 semantic_scores: std::collections::HashMap<String, f64>,
285 ) -> Result<RankedContextResult> {
286 self.get_ranked_context_cached_with_query_type(
287 query,
288 path,
289 max_tokens,
290 include_body,
291 depth,
292 graph_cache,
293 semantic_scores,
294 None,
295 )
296 }
297
298 #[allow(clippy::too_many_arguments)]
302 pub fn get_ranked_context_cached_with_query_type(
303 &self,
304 query: &str,
305 path: Option<&str>,
306 max_tokens: usize,
307 include_body: bool,
308 depth: usize,
309 graph_cache: Option<&crate::import_graph::GraphCache>,
310 semantic_scores: std::collections::HashMap<String, f64>,
311 query_type: Option<&str>,
312 ) -> Result<RankedContextResult> {
313 let all_symbols = if let Some(path) = path {
314 self.get_symbols_overview_cached(path, depth)?
315 } else {
316 self.select_solve_symbols_cached(query, depth)?
317 };
318
319 let ranking_ctx = match graph_cache {
320 Some(gc) => {
321 let pagerank_arc = gc.file_pagerank_scores(&self.project);
322 let pagerank = (*pagerank_arc).clone();
323 if semantic_scores.is_empty() {
324 RankingContext::with_pagerank(pagerank)
325 } else {
326 RankingContext::with_pagerank_and_semantic(query, pagerank, semantic_scores)
327 }
328 }
329 None => {
330 if semantic_scores.is_empty() {
331 RankingContext::text_only()
332 } else {
333 RankingContext::with_pagerank_and_semantic(
334 query,
335 std::collections::HashMap::new(),
336 semantic_scores,
337 )
338 }
339 }
340 };
341
342 let ranking_ctx = if let Some(qt) = query_type {
344 let mut ctx = ranking_ctx;
345 ctx.weights = ranking::weights_for_query_type(qt);
346 ctx
347 } else {
348 ranking_ctx
349 };
350
351 let flat_symbols: Vec<SymbolInfo> = all_symbols
352 .into_iter()
353 .flat_map(flatten_symbol_infos)
354 .collect();
355
356 let scored = rank_symbols(query, flat_symbols, &ranking_ctx);
357
358 let (selected, chars_used) =
359 prune_to_budget(scored, max_tokens, include_body, self.project.as_path());
360
361 Ok(RankedContextResult {
362 query: query.to_owned(),
363 count: selected.len(),
364 symbols: selected,
365 token_budget: max_tokens,
366 chars_used,
367 })
368 }
369
370 pub(super) fn rows_to_symbol_infos(
373 project: &ProjectRoot,
374 db: &IndexDb,
375 rows: Vec<crate::db::SymbolRow>,
376 include_body: bool,
377 ) -> Result<Vec<SymbolInfo>> {
378 let mut results = Vec::new();
379 let mut path_cache: std::collections::HashMap<i64, String> =
380 std::collections::HashMap::new();
381 for row in rows {
382 let rel_path = match path_cache.get(&row.file_id) {
383 Some(p) => p.clone(),
384 None => {
385 let p = db.get_file_path(row.file_id)?.unwrap_or_default();
386 path_cache.insert(row.file_id, p.clone());
387 p
388 }
389 };
390 let body = if include_body {
391 let abs = project.as_path().join(&rel_path);
392 fs::read_to_string(&abs)
393 .ok()
394 .map(|source| slice_source(&source, row.start_byte as u32, row.end_byte as u32))
395 } else {
396 None
397 };
398 let kind = SymbolKind::from_str_label(&row.kind);
399 let id = make_symbol_id(&rel_path, &kind, &row.name_path);
400 results.push(SymbolInfo {
401 name: row.name,
402 kind,
403 provenance: SymbolProvenance::from_path(&rel_path),
404 file_path: rel_path,
405 line: row.line as usize,
406 column: row.column_num as usize,
407 signature: row.signature,
408 name_path: row.name_path,
409 id,
410 body,
411 children: Vec::new(),
412 start_byte: row.start_byte as u32,
413 end_byte: row.end_byte as u32,
414 });
415 }
416 Ok(results)
417 }
418}