1use super::SymbolIndex;
2use super::parser::{flatten_symbol_infos, slice_source};
3use super::ranking::{self, RankingContext, prune_to_budget, rank_symbols};
4use super::types::{
5 RankedContextResult, SymbolInfo, SymbolKind, SymbolProvenance, make_symbol_id, parse_symbol_id,
6};
7use crate::db::IndexDb;
8use crate::project::ProjectRoot;
9use anyhow::Result;
10use std::fs;
11
12impl SymbolIndex {
13 pub(super) fn select_solve_symbols_cached(
21 &self,
22 query: &str,
23 depth: usize,
24 ) -> Result<Vec<SymbolInfo>> {
25 let query_lower = query.to_ascii_lowercase();
26 let query_tokens: Vec<&str> = query_lower
27 .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
28 .filter(|t| t.len() >= 3)
29 .collect();
30
31 let (top_files, importer_files) = {
36 let db = self.reader()?;
37 let all_paths = db.all_file_paths()?;
38
39 let mut file_scores: Vec<(String, usize)> = all_paths
40 .iter()
41 .map(|path| {
42 let path_lower = path.to_ascii_lowercase();
43 let score = query_tokens
44 .iter()
45 .filter(|t| path_lower.contains(**t))
46 .count();
47 (path.clone(), score)
48 })
49 .collect();
50
51 file_scores.sort_by(|a, b| b.1.cmp(&a.1));
52 let top: Vec<String> = file_scores
53 .iter()
54 .filter(|(_, score)| *score > 0)
55 .take(10)
56 .map(|(path, _)| path.clone())
57 .collect();
58
59 let mut importers = Vec::new();
61 if !top.is_empty() && top.len() <= 5 {
62 for file_path in top.iter().take(3) {
63 if let Ok(imp) = db.get_importers(file_path) {
64 for importer_path in imp.into_iter().take(3) {
65 importers.push(importer_path);
66 }
67 }
68 }
69 }
70
71 (top, importers)
72 };
74
75 let mut seen_ids = std::collections::HashSet::new();
76 let mut all_symbols = Vec::new();
77
78 for file_path in &top_files {
80 if let Ok(symbols) = self.get_symbols_overview_cached(file_path, depth) {
81 for sym in symbols {
82 if seen_ids.insert(sym.id.clone()) {
83 all_symbols.push(sym);
84 }
85 }
86 }
87 }
88
89 if let Ok(direct) = self.find_symbol_cached(query, None, false, false, 50) {
91 for sym in direct {
92 if seen_ids.insert(sym.id.clone()) {
93 all_symbols.push(sym);
94 }
95 }
96 }
97
98 for importer_path in &importer_files {
100 if let Ok(symbols) = self.get_symbols_overview_cached(importer_path, 1) {
101 for sym in symbols {
102 if seen_ids.insert(sym.id.clone()) {
103 all_symbols.push(sym);
104 }
105 }
106 }
107 }
108
109 if query_tokens.len() >= 2 {
111 for token in &query_tokens {
112 if let Ok(hits) = self.find_symbol_cached(token, None, false, false, 10) {
113 for sym in hits {
114 if seen_ids.insert(sym.id.clone()) {
115 all_symbols.push(sym);
116 }
117 }
118 }
119 }
120 }
121
122 if all_symbols.is_empty() {
124 return self.find_symbol_cached(query, None, false, false, 500);
125 }
126
127 Ok(all_symbols)
128 }
129
130 pub fn find_symbol_cached(
132 &self,
133 name: &str,
134 file_path: Option<&str>,
135 include_body: bool,
136 exact_match: bool,
137 max_matches: usize,
138 ) -> Result<Vec<SymbolInfo>> {
139 let db = self.reader()?;
140 if let Some((id_file, _id_kind, id_name_path)) = parse_symbol_id(name) {
142 let leaf_name = id_name_path.rsplit('/').next().unwrap_or(id_name_path);
143 let db_rows = db.find_symbols_by_name(leaf_name, Some(id_file), true, max_matches)?;
144 return Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body);
145 }
146
147 let resolved_fp = file_path.and_then(|fp| {
149 self.project
150 .resolve(fp)
151 .ok()
152 .map(|abs| self.project.to_relative(&abs))
153 });
154 let fp_ref = resolved_fp.as_deref().or(file_path);
155
156 let db_rows = db.find_symbols_by_name(name, fp_ref, exact_match, max_matches)?;
157 Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body)
158 }
159
160 pub fn get_symbols_overview_cached(
162 &self,
163 path: &str,
164 _depth: usize,
165 ) -> Result<Vec<SymbolInfo>> {
166 let db = self.reader()?;
167 let resolved = self.project.resolve(path)?;
168 if resolved.is_dir() {
169 let prefix = self.project.to_relative(&resolved);
170 let file_groups = db.get_symbols_for_directory(&prefix)?;
172 let mut symbols = Vec::new();
173 for (rel, file_symbols) in file_groups {
174 if file_symbols.is_empty() {
175 continue;
176 }
177 let id = make_symbol_id(&rel, &SymbolKind::File, &rel);
178 symbols.push(SymbolInfo {
179 name: rel.clone(),
180 kind: SymbolKind::File,
181 file_path: rel.clone(),
182 line: 0,
183 column: 0,
184 signature: format!(
185 "{} ({} symbols)",
186 std::path::Path::new(&rel)
187 .file_name()
188 .and_then(|n| n.to_str())
189 .unwrap_or(&rel),
190 file_symbols.len()
191 ),
192 name_path: rel.clone(),
193 id,
194 provenance: SymbolProvenance::from_path(&rel),
195 body: None,
196 children: file_symbols
197 .into_iter()
198 .map(|row| {
199 let kind = SymbolKind::from_str_label(&row.kind);
200 let sid = make_symbol_id(&rel, &kind, &row.name_path);
201 SymbolInfo {
202 name: row.name,
203 kind,
204 file_path: rel.clone(),
205 line: row.line as usize,
206 column: row.column_num as usize,
207 signature: row.signature,
208 name_path: row.name_path,
209 id: sid,
210 provenance: SymbolProvenance::from_path(&rel),
211 body: None,
212 children: Vec::new(),
213 start_byte: row.start_byte as u32,
214 end_byte: row.end_byte as u32,
215 }
216 })
217 .collect(),
218 start_byte: 0,
219 end_byte: 0,
220 });
221 }
222 return Ok(symbols);
223 }
224
225 let relative = self.project.to_relative(&resolved);
227 let file_row = match db.get_file(&relative)? {
228 Some(row) => row,
229 None => return Ok(Vec::new()),
230 };
231 let db_symbols = db.get_file_symbols(file_row.id)?;
232 Ok(db_symbols
233 .into_iter()
234 .map(|row| {
235 let kind = SymbolKind::from_str_label(&row.kind);
236 let id = make_symbol_id(&relative, &kind, &row.name_path);
237 SymbolInfo {
238 name: row.name,
239 kind,
240 file_path: relative.clone(),
241 provenance: SymbolProvenance::from_path(&relative),
242 line: row.line as usize,
243 column: row.column_num as usize,
244 signature: row.signature,
245 name_path: row.name_path,
246 id,
247 body: None,
248 children: Vec::new(),
249 start_byte: row.start_byte as u32,
250 end_byte: row.end_byte as u32,
251 }
252 })
253 .collect())
254 }
255
256 #[allow(clippy::too_many_arguments)]
260 pub fn get_ranked_context_cached(
261 &self,
262 query: &str,
263 path: Option<&str>,
264 max_tokens: usize,
265 include_body: bool,
266 depth: usize,
267 graph_cache: Option<&crate::import_graph::GraphCache>,
268 semantic_scores: std::collections::HashMap<String, f64>,
269 ) -> Result<RankedContextResult> {
270 self.get_ranked_context_cached_with_query_type(
271 query,
272 path,
273 max_tokens,
274 include_body,
275 depth,
276 graph_cache,
277 semantic_scores,
278 None,
279 )
280 }
281
282 pub fn get_ranked_context_cached_with_query_type(
286 &self,
287 query: &str,
288 path: Option<&str>,
289 max_tokens: usize,
290 include_body: bool,
291 depth: usize,
292 graph_cache: Option<&crate::import_graph::GraphCache>,
293 semantic_scores: std::collections::HashMap<String, f64>,
294 query_type: Option<&str>,
295 ) -> Result<RankedContextResult> {
296 let all_symbols = if let Some(path) = path {
297 self.get_symbols_overview_cached(path, depth)?
298 } else {
299 self.select_solve_symbols_cached(query, depth)?
300 };
301
302 let ranking_ctx = match graph_cache {
303 Some(gc) => {
304 let pagerank = gc.file_pagerank_scores(&self.project);
305 if semantic_scores.is_empty() {
306 RankingContext::with_pagerank(pagerank)
307 } else {
308 RankingContext::with_pagerank_and_semantic(query, pagerank, semantic_scores)
309 }
310 }
311 None => {
312 if semantic_scores.is_empty() {
313 RankingContext::text_only()
314 } else {
315 RankingContext::with_pagerank_and_semantic(
316 query,
317 std::collections::HashMap::new(),
318 semantic_scores,
319 )
320 }
321 }
322 };
323
324 let ranking_ctx = if let Some(qt) = query_type {
326 let mut ctx = ranking_ctx;
327 ctx.weights = ranking::weights_for_query_type(qt);
328 ctx
329 } else {
330 ranking_ctx
331 };
332
333 let flat_symbols: Vec<SymbolInfo> = all_symbols
334 .into_iter()
335 .flat_map(flatten_symbol_infos)
336 .collect();
337
338 let scored = rank_symbols(query, flat_symbols, &ranking_ctx);
339
340 let (selected, chars_used) =
341 prune_to_budget(scored, max_tokens, include_body, self.project.as_path());
342
343 Ok(RankedContextResult {
344 query: query.to_owned(),
345 count: selected.len(),
346 symbols: selected,
347 token_budget: max_tokens,
348 chars_used,
349 })
350 }
351
352 pub(super) fn rows_to_symbol_infos(
355 project: &ProjectRoot,
356 db: &IndexDb,
357 rows: Vec<crate::db::SymbolRow>,
358 include_body: bool,
359 ) -> Result<Vec<SymbolInfo>> {
360 let mut results = Vec::new();
361 let mut path_cache: std::collections::HashMap<i64, String> =
362 std::collections::HashMap::new();
363 for row in rows {
364 let rel_path = match path_cache.get(&row.file_id) {
365 Some(p) => p.clone(),
366 None => {
367 let p = db.get_file_path(row.file_id)?.unwrap_or_default();
368 path_cache.insert(row.file_id, p.clone());
369 p
370 }
371 };
372 let body = if include_body {
373 let abs = project.as_path().join(&rel_path);
374 fs::read_to_string(&abs)
375 .ok()
376 .map(|source| slice_source(&source, row.start_byte as u32, row.end_byte as u32))
377 } else {
378 None
379 };
380 let kind = SymbolKind::from_str_label(&row.kind);
381 let id = make_symbol_id(&rel_path, &kind, &row.name_path);
382 results.push(SymbolInfo {
383 name: row.name,
384 kind,
385 provenance: SymbolProvenance::from_path(&rel_path),
386 file_path: rel_path,
387 line: row.line as usize,
388 column: row.column_num as usize,
389 signature: row.signature,
390 name_path: row.name_path,
391 id,
392 body,
393 children: Vec::new(),
394 start_byte: row.start_byte as u32,
395 end_byte: row.end_byte as u32,
396 });
397 }
398 Ok(results)
399 }
400}