// codelens_engine/symbols/reader.rs

use super::SymbolIndex;
use super::parser::{flatten_symbol_infos, slice_source};
use super::ranking::{self, RankingContext, prune_to_budget, rank_symbols};
use super::types::{
    RankedContextResult, SymbolInfo, SymbolKind, SymbolProvenance, make_symbol_id, parse_symbol_id,
};
use crate::db::IndexDb;
use crate::project::ProjectRoot;
use anyhow::Result;
use std::collections::{HashMap, HashSet};
use std::fs;

12impl SymbolIndex {
13    /// Hybrid candidate collection: fan-out to multiple retrieval paths,
14    /// then merge and deduplicate. Returns a broad candidate pool for ranking.
15    ///
16    /// Retrieval paths:
17    /// 1. File path token matching — top files whose path contains query tokens
18    /// 2. Direct symbol name matching — exact/substring DB lookup
19    /// 3. Import graph proximity — files that import/are imported by matched files
20    pub(super) fn select_solve_symbols_cached(
21        &self,
22        query: &str,
23        depth: usize,
24    ) -> Result<Vec<SymbolInfo>> {
25        let query_lower = query.to_ascii_lowercase();
26        let query_tokens: Vec<&str> = query_lower
27            .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
28            .filter(|t| t.len() >= 3)
29            .collect();
30
31        // Compute file scores and import-graph proximity inside a block so the
32        // ReadDb guard is dropped before calling find_symbol_cached /
33        // get_symbols_overview_cached, which also acquire the reader lock.
34        // Holding both causes a deadlock when in_memory=true (same Mutex).
35        let (top_files, importer_files) = {
36            let db = self.reader()?;
37            let all_paths = db.all_file_paths()?;
38
39            let mut file_scores: Vec<(String, usize)> = all_paths
40                .iter()
41                .map(|path| {
42                    let path_lower = path.to_ascii_lowercase();
43                    let score = query_tokens
44                        .iter()
45                        .filter(|t| path_lower.contains(**t))
46                        .count();
47                    (path.clone(), score)
48                })
49                .collect();
50
51            file_scores.sort_by(|a, b| b.1.cmp(&a.1));
52            let top: Vec<String> = file_scores
53                .iter()
54                .filter(|(_, score)| *score > 0)
55                .take(10)
56                .map(|(path, _)| path.clone())
57                .collect();
58
59            // Path 3: import graph proximity
60            let mut importers = Vec::new();
61            if !top.is_empty() && top.len() <= 5 {
62                for file_path in top.iter().take(3) {
63                    if let Ok(imp) = db.get_importers(file_path) {
64                        for importer_path in imp.into_iter().take(3) {
65                            importers.push(importer_path);
66                        }
67                    }
68                }
69            }
70
71            (top, importers)
72            // db dropped here
73        };
74
75        let mut seen_ids = std::collections::HashSet::new();
76        let mut all_symbols = Vec::new();
77
78        // Path 1: collect symbols from path-matched files
79        for file_path in &top_files {
80            if let Ok(symbols) = self.get_symbols_overview_cached(file_path, depth) {
81                for sym in symbols {
82                    if seen_ids.insert(sym.id.clone()) {
83                        all_symbols.push(sym);
84                    }
85                }
86            }
87        }
88
89        // Path 2: direct symbol name matching
90        if let Ok(direct) = self.find_symbol_cached(query, None, false, false, 50) {
91            for sym in direct {
92                if seen_ids.insert(sym.id.clone()) {
93                    all_symbols.push(sym);
94                }
95            }
96        }
97
98        // Path 3: import graph proximity — related code via structural connection
99        for importer_path in &importer_files {
100            if let Ok(symbols) = self.get_symbols_overview_cached(importer_path, 1) {
101                for sym in symbols {
102                    if seen_ids.insert(sym.id.clone()) {
103                        all_symbols.push(sym);
104                    }
105                }
106            }
107        }
108
109        // Path 4: for multi-word queries, search individual tokens as symbol names
110        if query_tokens.len() >= 2 {
111            for token in &query_tokens {
112                if let Ok(hits) = self.find_symbol_cached(token, None, false, false, 10) {
113                    for sym in hits {
114                        if seen_ids.insert(sym.id.clone()) {
115                            all_symbols.push(sym);
116                        }
117                    }
118                }
119            }
120        }
121
122        // Fallback: if no candidates found, do a broad symbol search
123        if all_symbols.is_empty() {
124            return self.find_symbol_cached(query, None, false, false, 500);
125        }
126
127        Ok(all_symbols)
128    }
129
130    /// Query symbols from DB without lazy indexing. Returns empty if file not yet indexed.
131    pub fn find_symbol_cached(
132        &self,
133        name: &str,
134        file_path: Option<&str>,
135        include_body: bool,
136        exact_match: bool,
137        max_matches: usize,
138    ) -> Result<Vec<SymbolInfo>> {
139        let db = self.reader()?;
140        // Stable ID fast path
141        if let Some((id_file, _id_kind, id_name_path)) = parse_symbol_id(name) {
142            let leaf_name = id_name_path.rsplit('/').next().unwrap_or(id_name_path);
143            let db_rows = db.find_symbols_by_name(leaf_name, Some(id_file), true, max_matches)?;
144            return Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body);
145        }
146
147        // Resolve file_path (handles symlinks → canonical relative path)
148        let resolved_fp = file_path.and_then(|fp| {
149            self.project
150                .resolve(fp)
151                .ok()
152                .map(|abs| self.project.to_relative(&abs))
153        });
154        let fp_ref = resolved_fp.as_deref().or(file_path);
155
156        let db_rows = db.find_symbols_by_name(name, fp_ref, exact_match, max_matches)?;
157        Self::rows_to_symbol_infos(&self.project, &db, db_rows, include_body)
158    }
159
160    /// Get symbols overview from DB without lazy indexing.
161    pub fn get_symbols_overview_cached(
162        &self,
163        path: &str,
164        _depth: usize,
165    ) -> Result<Vec<SymbolInfo>> {
166        let db = self.reader()?;
167        let resolved = self.project.resolve(path)?;
168        if resolved.is_dir() {
169            let prefix = self.project.to_relative(&resolved);
170            // Single JOIN query instead of N+1 (all_file_paths + get_file + get_file_symbols per file)
171            let file_groups = db.get_symbols_for_directory(&prefix)?;
172            let mut symbols = Vec::new();
173            for (rel, file_symbols) in file_groups {
174                if file_symbols.is_empty() {
175                    continue;
176                }
177                let id = make_symbol_id(&rel, &SymbolKind::File, &rel);
178                symbols.push(SymbolInfo {
179                    name: rel.clone(),
180                    kind: SymbolKind::File,
181                    file_path: rel.clone(),
182                    line: 0,
183                    column: 0,
184                    signature: format!(
185                        "{} ({} symbols)",
186                        std::path::Path::new(&rel)
187                            .file_name()
188                            .and_then(|n| n.to_str())
189                            .unwrap_or(&rel),
190                        file_symbols.len()
191                    ),
192                    name_path: rel.clone(),
193                    id,
194                    provenance: SymbolProvenance::from_path(&rel),
195                    body: None,
196                    children: file_symbols
197                        .into_iter()
198                        .map(|row| {
199                            let kind = SymbolKind::from_str_label(&row.kind);
200                            let sid = make_symbol_id(&rel, &kind, &row.name_path);
201                            SymbolInfo {
202                                name: row.name,
203                                kind,
204                                file_path: rel.clone(),
205                                line: row.line as usize,
206                                column: row.column_num as usize,
207                                signature: row.signature,
208                                name_path: row.name_path,
209                                id: sid,
210                                provenance: SymbolProvenance::from_path(&rel),
211                                body: None,
212                                children: Vec::new(),
213                                start_byte: row.start_byte as u32,
214                                end_byte: row.end_byte as u32,
215                            }
216                        })
217                        .collect(),
218                    start_byte: 0,
219                    end_byte: 0,
220                });
221            }
222            return Ok(symbols);
223        }
224
225        // Single file
226        let relative = self.project.to_relative(&resolved);
227        let file_row = match db.get_file(&relative)? {
228            Some(row) => row,
229            None => return Ok(Vec::new()),
230        };
231        let db_symbols = db.get_file_symbols(file_row.id)?;
232        Ok(db_symbols
233            .into_iter()
234            .map(|row| {
235                let kind = SymbolKind::from_str_label(&row.kind);
236                let id = make_symbol_id(&relative, &kind, &row.name_path);
237                SymbolInfo {
238                    name: row.name,
239                    kind,
240                    file_path: relative.clone(),
241                    provenance: SymbolProvenance::from_path(&relative),
242                    line: row.line as usize,
243                    column: row.column_num as usize,
244                    signature: row.signature,
245                    name_path: row.name_path,
246                    id,
247                    body: None,
248                    children: Vec::new(),
249                    start_byte: row.start_byte as u32,
250                    end_byte: row.end_byte as u32,
251                }
252            })
253            .collect())
254    }
255
256    /// Ranked context from DB without lazy indexing.
257    /// If `graph_cache` is provided, PageRank scores boost symbols in highly-imported files.
258    /// If `semantic_scores` is non-empty, vector similarity is blended into ranking.
259    #[allow(clippy::too_many_arguments)]
260    pub fn get_ranked_context_cached(
261        &self,
262        query: &str,
263        path: Option<&str>,
264        max_tokens: usize,
265        include_body: bool,
266        depth: usize,
267        graph_cache: Option<&crate::import_graph::GraphCache>,
268        semantic_scores: std::collections::HashMap<String, f64>,
269    ) -> Result<RankedContextResult> {
270        self.get_ranked_context_cached_with_query_type(
271            query,
272            path,
273            max_tokens,
274            include_body,
275            depth,
276            graph_cache,
277            semantic_scores,
278            None,
279        )
280    }
281
282    /// Like `get_ranked_context_cached` but accepts an optional query type
283    /// (`"identifier"`, `"natural_language"`, `"short_phrase"`) to tune
284    /// ranking weights per query category.
285    pub fn get_ranked_context_cached_with_query_type(
286        &self,
287        query: &str,
288        path: Option<&str>,
289        max_tokens: usize,
290        include_body: bool,
291        depth: usize,
292        graph_cache: Option<&crate::import_graph::GraphCache>,
293        semantic_scores: std::collections::HashMap<String, f64>,
294        query_type: Option<&str>,
295    ) -> Result<RankedContextResult> {
296        let all_symbols = if let Some(path) = path {
297            self.get_symbols_overview_cached(path, depth)?
298        } else {
299            self.select_solve_symbols_cached(query, depth)?
300        };
301
302        let ranking_ctx = match graph_cache {
303            Some(gc) => {
304                let pagerank = gc.file_pagerank_scores(&self.project);
305                if semantic_scores.is_empty() {
306                    RankingContext::with_pagerank(pagerank)
307                } else {
308                    RankingContext::with_pagerank_and_semantic(query, pagerank, semantic_scores)
309                }
310            }
311            None => {
312                if semantic_scores.is_empty() {
313                    RankingContext::text_only()
314                } else {
315                    RankingContext::with_pagerank_and_semantic(
316                        query,
317                        std::collections::HashMap::new(),
318                        semantic_scores,
319                    )
320                }
321            }
322        };
323
324        // Apply query-type-aware weights when specified.
325        let ranking_ctx = if let Some(qt) = query_type {
326            let mut ctx = ranking_ctx;
327            ctx.weights = ranking::weights_for_query_type(qt);
328            ctx
329        } else {
330            ranking_ctx
331        };
332
333        let flat_symbols: Vec<SymbolInfo> = all_symbols
334            .into_iter()
335            .flat_map(flatten_symbol_infos)
336            .collect();
337
338        let scored = rank_symbols(query, flat_symbols, &ranking_ctx);
339
340        let (selected, chars_used) =
341            prune_to_budget(scored, max_tokens, include_body, self.project.as_path());
342
343        Ok(RankedContextResult {
344            query: query.to_owned(),
345            count: selected.len(),
346            symbols: selected,
347            token_budget: max_tokens,
348            chars_used,
349        })
350    }
351
352    /// Helper: convert DB rows to SymbolInfo with optional body.
353    /// Uses a file_id→path cache to avoid N+1 `get_file_path` queries.
354    pub(super) fn rows_to_symbol_infos(
355        project: &ProjectRoot,
356        db: &IndexDb,
357        rows: Vec<crate::db::SymbolRow>,
358        include_body: bool,
359    ) -> Result<Vec<SymbolInfo>> {
360        let mut results = Vec::new();
361        let mut path_cache: std::collections::HashMap<i64, String> =
362            std::collections::HashMap::new();
363        for row in rows {
364            let rel_path = match path_cache.get(&row.file_id) {
365                Some(p) => p.clone(),
366                None => {
367                    let p = db.get_file_path(row.file_id)?.unwrap_or_default();
368                    path_cache.insert(row.file_id, p.clone());
369                    p
370                }
371            };
372            let body = if include_body {
373                let abs = project.as_path().join(&rel_path);
374                fs::read_to_string(&abs)
375                    .ok()
376                    .map(|source| slice_source(&source, row.start_byte as u32, row.end_byte as u32))
377            } else {
378                None
379            };
380            let kind = SymbolKind::from_str_label(&row.kind);
381            let id = make_symbol_id(&rel_path, &kind, &row.name_path);
382            results.push(SymbolInfo {
383                name: row.name,
384                kind,
385                provenance: SymbolProvenance::from_path(&rel_path),
386                file_path: rel_path,
387                line: row.line as usize,
388                column: row.column_num as usize,
389                signature: row.signature,
390                name_path: row.name_path,
391                id,
392                body,
393                children: Vec::new(),
394                start_byte: row.start_byte as u32,
395                end_byte: row.end_byte as u32,
396            });
397        }
398        Ok(results)
399    }
400}