Skip to main content

code_analyze_mcp/
analyze.rs

1//! Main analysis engine for extracting code structure from files and directories.
2//!
3//! Implements the four MCP tools: `analyze_directory` (Overview), `analyze_file` (FileDetails),
4//! `analyze_symbol` (call graph), and `analyze_module` (lightweight index). Handles parallel processing and cancellation.
5
6use crate::formatter::{
7    format_file_details, format_focused, format_focused_summary, format_structure,
8};
9use crate::graph::{CallGraph, InternalCallChain, resolve_symbol};
10use crate::lang::language_from_extension;
11use crate::parser::{ElementExtractor, SemanticExtractor, extract_impl_traits};
12use crate::test_detection::is_test_file;
13use crate::traversal::{WalkEntry, walk_directory};
14use crate::types::{
15    AnalysisMode, FileInfo, ImplTraitInfo, ImportInfo, SemanticAnalysis, SymbolMatchMode,
16};
17use rayon::prelude::*;
18use schemars::JsonSchema;
19use serde::Serialize;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::time::Instant;
24use thiserror::Error;
25use tokio_util::sync::CancellationToken;
26use tracing::instrument;
27
28#[derive(Debug, Error)]
29pub enum AnalyzeError {
30    #[error("Traversal error: {0}")]
31    Traversal(#[from] crate::traversal::TraversalError),
32    #[error("Parser error: {0}")]
33    Parser(#[from] crate::parser::ParserError),
34    #[error("Graph error: {0}")]
35    Graph(#[from] crate::graph::GraphError),
36    #[error("Formatter error: {0}")]
37    Formatter(#[from] crate::formatter::FormatterError),
38    #[error("Analysis cancelled")]
39    Cancelled,
40}
41
42/// Result of directory analysis containing both formatted output and file data.
43#[derive(Debug, Clone, Serialize, JsonSchema)]
44pub struct AnalysisOutput {
45    #[schemars(description = "Formatted text representation of the analysis")]
46    pub formatted: String,
47    #[schemars(description = "List of files analyzed in the directory")]
48    pub files: Vec<FileInfo>,
49    /// Walk entries used internally for summary generation; not serialized.
50    #[serde(skip)]
51    #[schemars(skip)]
52    pub entries: Vec<WalkEntry>,
53    /// Subtree file counts computed from an unbounded walk; used by format_summary; not serialized.
54    #[serde(skip)]
55    #[schemars(skip)]
56    pub subtree_counts: Option<Vec<(std::path::PathBuf, usize)>>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    #[schemars(
59        description = "Opaque cursor token for the next page of results (absent when no more results)"
60    )]
61    pub next_cursor: Option<String>,
62}
63
64/// Result of file-level semantic analysis.
65#[derive(Debug, Clone, Serialize, JsonSchema)]
66pub struct FileAnalysisOutput {
67    #[schemars(description = "Formatted text representation of the analysis")]
68    pub formatted: String,
69    #[schemars(description = "Semantic analysis data including functions, classes, and imports")]
70    pub semantic: SemanticAnalysis,
71    #[schemars(description = "Total line count of the analyzed file")]
72    #[schemars(schema_with = "crate::schema_helpers::integer_schema")]
73    pub line_count: usize,
74    #[serde(skip_serializing_if = "Option::is_none")]
75    #[schemars(
76        description = "Opaque cursor token for the next page of results (absent when no more results)"
77    )]
78    pub next_cursor: Option<String>,
79}
80
81/// Analyze a directory structure with progress tracking.
82#[instrument(skip_all, fields(path = %root.display()))]
83pub fn analyze_directory_with_progress(
84    root: &Path,
85    entries: Vec<WalkEntry>,
86    progress: Arc<AtomicUsize>,
87    ct: CancellationToken,
88) -> Result<AnalysisOutput, AnalyzeError> {
89    // Check if already cancelled
90    if ct.is_cancelled() {
91        return Err(AnalyzeError::Cancelled);
92    }
93
94    // Detect language from file extension
95    let file_entries: Vec<&WalkEntry> = entries.iter().filter(|e| !e.is_dir).collect();
96
97    let start = Instant::now();
98    tracing::debug!(file_count = file_entries.len(), root = %root.display(), "analysis start");
99
100    // Parallel analysis of files
101    let analysis_results: Vec<FileInfo> = file_entries
102        .par_iter()
103        .filter_map(|entry| {
104            // Check cancellation per file
105            if ct.is_cancelled() {
106                return None;
107            }
108
109            let path_str = entry.path.display().to_string();
110
111            // Detect language from extension
112            let ext = entry.path.extension().and_then(|e| e.to_str());
113
114            // Try to read file content
115            let source = match std::fs::read_to_string(&entry.path) {
116                Ok(content) => content,
117                Err(_) => {
118                    // Binary file or unreadable - exclude from output
119                    progress.fetch_add(1, Ordering::Relaxed);
120                    return None;
121                }
122            };
123
124            // Count lines
125            let line_count = source.lines().count();
126
127            // Detect language and extract counts
128            let (language, function_count, class_count) = if let Some(ext_str) = ext {
129                if let Some(lang) = language_from_extension(ext_str) {
130                    let lang_str = lang.to_string();
131                    match ElementExtractor::extract_with_depth(&source, &lang_str) {
132                        Ok((func_count, class_count)) => (lang_str, func_count, class_count),
133                        Err(_) => (lang_str, 0, 0),
134                    }
135                } else {
136                    ("unknown".to_string(), 0, 0)
137                }
138            } else {
139                ("unknown".to_string(), 0, 0)
140            };
141
142            progress.fetch_add(1, Ordering::Relaxed);
143
144            let is_test = is_test_file(&entry.path);
145
146            Some(FileInfo {
147                path: path_str,
148                line_count,
149                function_count,
150                class_count,
151                language,
152                is_test,
153            })
154        })
155        .collect();
156
157    // Check if cancelled after parallel processing
158    if ct.is_cancelled() {
159        return Err(AnalyzeError::Cancelled);
160    }
161
162    tracing::debug!(
163        file_count = file_entries.len(),
164        duration_ms = start.elapsed().as_millis() as u64,
165        "analysis complete"
166    );
167
168    // Format output
169    let formatted = format_structure(&entries, &analysis_results, None);
170
171    Ok(AnalysisOutput {
172        formatted,
173        files: analysis_results,
174        entries,
175        next_cursor: None,
176        subtree_counts: None,
177    })
178}
179
180/// Analyze a directory structure and return formatted output and file data.
181#[instrument(skip_all, fields(path = %root.display()))]
182pub fn analyze_directory(
183    root: &Path,
184    max_depth: Option<u32>,
185) -> Result<AnalysisOutput, AnalyzeError> {
186    let entries = walk_directory(root, max_depth)?;
187    let counter = Arc::new(AtomicUsize::new(0));
188    let ct = CancellationToken::new();
189    analyze_directory_with_progress(root, entries, counter, ct)
190}
191
192/// Determine analysis mode based on parameters and path.
193pub fn determine_mode(path: &str, focus: Option<&str>) -> AnalysisMode {
194    if focus.is_some() {
195        return AnalysisMode::SymbolFocus;
196    }
197
198    let path_obj = Path::new(path);
199    if path_obj.is_dir() {
200        AnalysisMode::Overview
201    } else {
202        AnalysisMode::FileDetails
203    }
204}
205
206/// Analyze a single file and return semantic analysis with formatted output.
207#[instrument(skip_all, fields(path))]
208pub fn analyze_file(
209    path: &str,
210    ast_recursion_limit: Option<usize>,
211) -> Result<FileAnalysisOutput, AnalyzeError> {
212    let start = Instant::now();
213    let source = std::fs::read_to_string(path)
214        .map_err(|e| AnalyzeError::Parser(crate::parser::ParserError::ParseError(e.to_string())))?;
215
216    let line_count = source.lines().count();
217
218    // Detect language from extension
219    let ext = Path::new(path)
220        .extension()
221        .and_then(|e| e.to_str())
222        .and_then(language_from_extension)
223        .map(|l| l.to_string())
224        .unwrap_or_else(|| "unknown".to_string());
225
226    // Extract semantic information
227    let mut semantic = SemanticExtractor::extract(&source, &ext, ast_recursion_limit)?;
228
229    // Populate the file path on references now that the path is known
230    for r in &mut semantic.references {
231        r.location = path.to_string();
232    }
233
234    // Resolve Python wildcard imports
235    if ext == "python" {
236        resolve_wildcard_imports(Path::new(path), &mut semantic.imports);
237    }
238
239    // Detect if this is a test file
240    let is_test = is_test_file(Path::new(path));
241
242    // Extract parent directory for relative path display
243    let parent_dir = Path::new(path).parent();
244
245    // Format output
246    let formatted = format_file_details(path, &semantic, line_count, is_test, parent_dir);
247
248    tracing::debug!(path = %path, language = %ext, functions = semantic.functions.len(), classes = semantic.classes.len(), imports = semantic.imports.len(), duration_ms = start.elapsed().as_millis() as u64, "file analysis complete");
249
250    Ok(FileAnalysisOutput {
251        formatted,
252        semantic,
253        line_count,
254        next_cursor: None,
255    })
256}
257
258/// Result of focused symbol analysis.
259#[derive(Debug, Serialize, JsonSchema)]
260pub struct FocusedAnalysisOutput {
261    #[schemars(description = "Formatted text representation of the call graph analysis")]
262    pub formatted: String,
263    #[serde(skip_serializing_if = "Option::is_none")]
264    #[schemars(
265        description = "Opaque cursor token for the next page of results (absent when no more results)"
266    )]
267    pub next_cursor: Option<String>,
268    /// Production caller chains (partitioned from incoming chains, excluding test callers).
269    /// Not serialized; used for pagination in lib.rs.
270    #[serde(skip)]
271    #[schemars(skip)]
272    pub(crate) prod_chains: Vec<InternalCallChain>,
273    /// Test caller chains. Not serialized; used for pagination summary in lib.rs.
274    #[serde(skip)]
275    #[schemars(skip)]
276    pub(crate) test_chains: Vec<InternalCallChain>,
277    /// Outgoing (callee) chains. Not serialized; used for pagination in lib.rs.
278    #[serde(skip)]
279    #[schemars(skip)]
280    pub(crate) outgoing_chains: Vec<InternalCallChain>,
281    /// Number of definitions for the symbol. Not serialized; used for pagination headers.
282    #[serde(skip)]
283    #[schemars(skip)]
284    pub def_count: usize,
285    /// Total unique callers before impl_only filter. Not serialized; used for FILTER header.
286    #[serde(skip)]
287    #[schemars(skip)]
288    pub unfiltered_caller_count: usize,
289    /// Unique callers after impl_only filter. Not serialized; used for FILTER header.
290    #[serde(skip)]
291    #[schemars(skip)]
292    pub impl_trait_caller_count: usize,
293}
294
295/// Analyze a symbol's call graph across a directory with progress tracking.
296#[instrument(skip_all, fields(path = %root.display(), symbol = %focus))]
297#[allow(clippy::too_many_arguments)]
298pub fn analyze_focused_with_progress(
299    root: &Path,
300    focus: &str,
301    match_mode: SymbolMatchMode,
302    follow_depth: u32,
303    max_depth: Option<u32>,
304    ast_recursion_limit: Option<usize>,
305    progress: Arc<AtomicUsize>,
306    ct: CancellationToken,
307    use_summary: bool,
308    impl_only: Option<bool>,
309) -> Result<FocusedAnalysisOutput, AnalyzeError> {
310    #[allow(clippy::too_many_arguments)]
311    // Check if already cancelled
312    if ct.is_cancelled() {
313        return Err(AnalyzeError::Cancelled);
314    }
315
316    // Check if path is a file (hint to use directory)
317    if root.is_file() {
318        let formatted =
319            "Single-file focus not supported. Please provide a directory path for cross-file call graph analysis.\n"
320                .to_string();
321        return Ok(FocusedAnalysisOutput {
322            formatted,
323            next_cursor: None,
324            prod_chains: vec![],
325            test_chains: vec![],
326            outgoing_chains: vec![],
327            def_count: 0,
328            unfiltered_caller_count: 0,
329            impl_trait_caller_count: 0,
330        });
331    }
332
333    // Walk the directory
334    let entries = walk_directory(root, max_depth)?;
335
336    // Collect semantic analysis for all files in parallel
337    let file_entries: Vec<&WalkEntry> = entries.iter().filter(|e| !e.is_dir).collect();
338
339    let analysis_results: Vec<(PathBuf, SemanticAnalysis)> = file_entries
340        .par_iter()
341        .filter_map(|entry| {
342            // Check cancellation per file
343            if ct.is_cancelled() {
344                return None;
345            }
346
347            let ext = entry.path.extension().and_then(|e| e.to_str());
348
349            // Try to read file content
350            let source = match std::fs::read_to_string(&entry.path) {
351                Ok(content) => content,
352                Err(_) => {
353                    progress.fetch_add(1, Ordering::Relaxed);
354                    return None;
355                }
356            };
357
358            // Detect language and extract semantic information
359            let language = if let Some(ext_str) = ext {
360                language_from_extension(ext_str)
361                    .map(|l| l.to_string())
362                    .unwrap_or_else(|| "unknown".to_string())
363            } else {
364                "unknown".to_string()
365            };
366
367            match SemanticExtractor::extract(&source, &language, ast_recursion_limit) {
368                Ok(mut semantic) => {
369                    // Populate file path on references
370                    for r in &mut semantic.references {
371                        r.location = entry.path.display().to_string();
372                    }
373                    // Extract impl-trait blocks independently (Rust only; empty for other langs)
374                    if language == "rust" {
375                        semantic.impl_traits = extract_impl_traits(&source, &entry.path);
376                    }
377                    progress.fetch_add(1, Ordering::Relaxed);
378                    Some((entry.path.clone(), semantic))
379                }
380                Err(_) => {
381                    progress.fetch_add(1, Ordering::Relaxed);
382                    None
383                }
384            }
385        })
386        .collect();
387
388    // Check if cancelled after parallel processing
389    if ct.is_cancelled() {
390        return Err(AnalyzeError::Cancelled);
391    }
392
393    // Collect all impl-trait info from analysis results
394    let all_impl_traits: Vec<ImplTraitInfo> = analysis_results
395        .iter()
396        .flat_map(|(_, sem)| sem.impl_traits.iter().cloned())
397        .collect();
398
399    // Build call graph. Always build without impl_only filter first so we can
400    // record the unfiltered caller count before discarding those edges.
401    let mut graph = CallGraph::build_from_results(
402        analysis_results,
403        &all_impl_traits,
404        false, // filter applied below after counting
405    )?;
406
407    // Resolve symbol name using the requested match mode.
408    // Exact mode: check the graph directly without building a sorted set (O(1) lookups).
409    // Fuzzy modes: collect a sorted, deduplicated set of all known symbols for deterministic results.
410    let resolved_focus = if match_mode == SymbolMatchMode::Exact {
411        let exists = graph.definitions.contains_key(focus)
412            || graph.callers.contains_key(focus)
413            || graph.callees.contains_key(focus);
414        if exists {
415            focus.to_string()
416        } else {
417            return Err(crate::graph::GraphError::SymbolNotFound {
418                symbol: focus.to_string(),
419                hint: "Try match_mode=insensitive for a case-insensitive search.".to_string(),
420            }
421            .into());
422        }
423    } else {
424        let all_known: Vec<String> = graph
425            .definitions
426            .keys()
427            .chain(graph.callers.keys())
428            .chain(graph.callees.keys())
429            .cloned()
430            .collect::<std::collections::BTreeSet<_>>()
431            .into_iter()
432            .collect();
433        resolve_symbol(all_known.iter(), focus, &match_mode)?
434    };
435
436    // Count unique callers for the focus symbol before applying impl_only filter.
437    let unfiltered_caller_count = graph
438        .callers
439        .get(&resolved_focus)
440        .map(|edges| {
441            edges
442                .iter()
443                .map(|e| &e.neighbor_name)
444                .collect::<std::collections::HashSet<_>>()
445                .len()
446        })
447        .unwrap_or(0);
448
449    // Apply impl_only filter now if requested, then count filtered callers.
450    // Filter all caller adjacency lists so traversal and formatting are consistently
451    // restricted to impl-trait edges regardless of follow_depth.
452    let impl_trait_caller_count = if impl_only.unwrap_or(false) {
453        for edges in graph.callers.values_mut() {
454            edges.retain(|e| e.is_impl_trait);
455        }
456        graph
457            .callers
458            .get(&resolved_focus)
459            .map(|edges| {
460                edges
461                    .iter()
462                    .map(|e| &e.neighbor_name)
463                    .collect::<std::collections::HashSet<_>>()
464                    .len()
465            })
466            .unwrap_or(0)
467    } else {
468        unfiltered_caller_count
469    };
470
471    // Compute chain data for pagination (always, regardless of summary mode)
472    let def_count = graph
473        .definitions
474        .get(&resolved_focus)
475        .map_or(0, |d| d.len());
476    let incoming_chains = graph.find_incoming_chains(&resolved_focus, follow_depth)?;
477    let outgoing_chains = graph.find_outgoing_chains(&resolved_focus, follow_depth)?;
478
479    let (prod_chains, test_chains): (Vec<_>, Vec<_>) =
480        incoming_chains.into_iter().partition(|chain| {
481            chain
482                .chain
483                .first()
484                .is_none_or(|(name, path, _)| !is_test_file(path) && !name.starts_with("test_"))
485        });
486
487    // Format output
488    let formatted = if use_summary {
489        format_focused_summary(&graph, &resolved_focus, follow_depth, Some(root))?
490    } else {
491        format_focused(&graph, &resolved_focus, follow_depth, Some(root))?
492    };
493
494    Ok(FocusedAnalysisOutput {
495        formatted,
496        next_cursor: None,
497        prod_chains,
498        test_chains,
499        outgoing_chains,
500        def_count,
501        unfiltered_caller_count,
502        impl_trait_caller_count,
503    })
504}
505
506#[instrument(skip_all, fields(path = %root.display(), symbol = %focus))]
507pub fn analyze_focused(
508    root: &Path,
509    focus: &str,
510    follow_depth: u32,
511    max_depth: Option<u32>,
512    ast_recursion_limit: Option<usize>,
513) -> Result<FocusedAnalysisOutput, AnalyzeError> {
514    let counter = Arc::new(AtomicUsize::new(0));
515    let ct = CancellationToken::new();
516    analyze_focused_with_progress(
517        root,
518        focus,
519        SymbolMatchMode::Exact,
520        follow_depth,
521        max_depth,
522        ast_recursion_limit,
523        counter,
524        ct,
525        false,
526        None,
527    )
528}
529
530/// Analyze a single file and return a minimal fixed schema (name, line count, language,
531/// functions, imports) for lightweight code understanding.
532#[instrument(skip_all, fields(path))]
533pub fn analyze_module_file(path: &str) -> Result<crate::types::ModuleInfo, AnalyzeError> {
534    let source = std::fs::read_to_string(path)
535        .map_err(|e| AnalyzeError::Parser(crate::parser::ParserError::ParseError(e.to_string())))?;
536
537    let file_path = Path::new(path);
538    let name = file_path
539        .file_name()
540        .and_then(|s| s.to_str())
541        .unwrap_or("unknown")
542        .to_string();
543
544    let line_count = source.lines().count();
545
546    let language = file_path
547        .extension()
548        .and_then(|e| e.to_str())
549        .and_then(language_from_extension)
550        .ok_or_else(|| {
551            AnalyzeError::Parser(crate::parser::ParserError::ParseError(
552                "unsupported or missing file extension".to_string(),
553            ))
554        })?;
555
556    let semantic = SemanticExtractor::extract(&source, language, None)?;
557
558    let functions = semantic
559        .functions
560        .into_iter()
561        .map(|f| crate::types::ModuleFunctionInfo {
562            name: f.name,
563            line: f.line,
564        })
565        .collect();
566
567    let imports = semantic
568        .imports
569        .into_iter()
570        .map(|i| crate::types::ModuleImportInfo {
571            module: i.module,
572            items: i.items,
573        })
574        .collect();
575
576    Ok(crate::types::ModuleInfo {
577        name,
578        line_count,
579        language: language.to_string(),
580        functions,
581        imports,
582    })
583}
584
585/// Resolve Python wildcard imports to actual symbol names.
586///
587/// For each import with items=["*"], this function:
588/// 1. Parses the relative dots (if any) and climbs the directory tree
589/// 2. Finds the target .py file or __init__.py
590/// 3. Extracts symbols (functions and classes) from the target
591/// 4. Honors __all__ if defined, otherwise uses function+class names
592///
593/// All resolution failures are non-fatal: debug-logged and the wildcard is preserved.
594fn resolve_wildcard_imports(file_path: &Path, imports: &mut [ImportInfo]) {
595    use std::collections::HashMap;
596
597    let mut resolved_cache: HashMap<PathBuf, Vec<String>> = HashMap::new();
598    let file_path_canonical = match file_path.canonicalize() {
599        Ok(p) => p,
600        Err(_) => {
601            tracing::debug!(file = ?file_path, "unable to canonicalize current file path");
602            return;
603        }
604    };
605
606    for import in imports.iter_mut() {
607        if import.items != ["*"] {
608            continue;
609        }
610        resolve_single_wildcard(import, file_path, &file_path_canonical, &mut resolved_cache);
611    }
612}
613
614/// Resolve one wildcard import in place. On any failure the import is left unchanged.
615fn resolve_single_wildcard(
616    import: &mut ImportInfo,
617    file_path: &Path,
618    file_path_canonical: &Path,
619    resolved_cache: &mut std::collections::HashMap<PathBuf, Vec<String>>,
620) {
621    let module = import.module.clone();
622    let dot_count = module.chars().take_while(|c| *c == '.').count();
623    if dot_count == 0 {
624        return;
625    }
626    let module_path = module.trim_start_matches('.');
627
628    let target_to_read = match locate_target_file(file_path, dot_count, module_path, &module) {
629        Some(p) => p,
630        None => return,
631    };
632
633    let canonical = match target_to_read.canonicalize() {
634        Ok(p) => p,
635        Err(_) => {
636            tracing::debug!(target = ?target_to_read, import = %module, "unable to canonicalize path");
637            return;
638        }
639    };
640
641    if canonical == file_path_canonical {
642        tracing::debug!(target = ?canonical, import = %module, "cannot import from self");
643        return;
644    }
645
646    if let Some(cached) = resolved_cache.get(&canonical) {
647        tracing::debug!(import = %module, symbols_count = cached.len(), "using cached symbols");
648        import.items = cached.clone();
649        return;
650    }
651
652    if let Some(symbols) = parse_target_symbols(&target_to_read, &module) {
653        tracing::debug!(import = %module, resolved_count = symbols.len(), "wildcard import resolved");
654        import.items = symbols.clone();
655        resolved_cache.insert(canonical, symbols);
656    }
657}
658
659/// Locate the .py file that a wildcard import refers to. Returns None if not found.
660fn locate_target_file(
661    file_path: &Path,
662    dot_count: usize,
663    module_path: &str,
664    module: &str,
665) -> Option<PathBuf> {
666    let mut target_dir = file_path.parent()?.to_path_buf();
667
668    for _ in 1..dot_count {
669        if !target_dir.pop() {
670            tracing::debug!(import = %module, "unable to climb {} levels", dot_count.saturating_sub(1));
671            return None;
672        }
673    }
674
675    let target_file = if module_path.is_empty() {
676        target_dir.join("__init__.py")
677    } else {
678        let rel_path = module_path.replace('.', "/");
679        target_dir.join(format!("{rel_path}.py"))
680    };
681
682    if target_file.exists() {
683        Some(target_file)
684    } else if target_file.with_extension("").is_dir() {
685        let init = target_file.with_extension("").join("__init__.py");
686        if init.exists() { Some(init) } else { None }
687    } else {
688        tracing::debug!(target = ?target_file, import = %module, "target file not found");
689        None
690    }
691}
692
693/// Read and parse a target .py file, returning its exported symbols.
694fn parse_target_symbols(target_path: &Path, module: &str) -> Option<Vec<String>> {
695    let source = match std::fs::read_to_string(target_path) {
696        Ok(s) => s,
697        Err(e) => {
698            tracing::debug!(target = ?target_path, import = %module, error = %e, "unable to read target file");
699            return None;
700        }
701    };
702
703    // Parse once with tree-sitter
704    use tree_sitter::Parser;
705    let lang_info = crate::languages::get_language_info("python")?;
706    let mut parser = Parser::new();
707    if parser.set_language(&lang_info.language).is_err() {
708        return None;
709    }
710    let tree = parser.parse(&source, None)?;
711
712    // First, try to extract __all__ from the same tree
713    let mut symbols = Vec::new();
714    extract_all_from_tree(&tree, &source, &mut symbols);
715    if !symbols.is_empty() {
716        tracing::debug!(import = %module, symbols = ?symbols, "using __all__ symbols");
717        return Some(symbols);
718    }
719
720    // Fallback: extract functions/classes from the tree
721    let root = tree.root_node();
722    let mut cursor = root.walk();
723    for child in root.children(&mut cursor) {
724        match child.kind() {
725            "function_definition" => {
726                if let Some(name_node) = child.child_by_field_name("name") {
727                    let name = source[name_node.start_byte()..name_node.end_byte()].to_string();
728                    if !name.starts_with('_') {
729                        symbols.push(name);
730                    }
731                }
732            }
733            "class_definition" => {
734                if let Some(name_node) = child.child_by_field_name("name") {
735                    let name = source[name_node.start_byte()..name_node.end_byte()].to_string();
736                    if !name.starts_with('_') {
737                        symbols.push(name);
738                    }
739                }
740            }
741            _ => {}
742        }
743    }
744    tracing::debug!(import = %module, fallback_symbols = ?symbols, "using fallback function/class names");
745    Some(symbols)
746}
747
748/// Extract __all__ from a tree-sitter tree.
749fn extract_all_from_tree(tree: &tree_sitter::Tree, source: &str, result: &mut Vec<String>) {
750    let root = tree.root_node();
751    let mut cursor = root.walk();
752    for child in root.children(&mut cursor) {
753        if child.kind() == "simple_statement" {
754            // simple_statement contains assignment and other statement types
755            let mut simple_cursor = child.walk();
756            for simple_child in child.children(&mut simple_cursor) {
757                if simple_child.kind() == "assignment"
758                    && let Some(left) = simple_child.child_by_field_name("left")
759                {
760                    let target_text = source[left.start_byte()..left.end_byte()].trim();
761                    if target_text == "__all__"
762                        && let Some(right) = simple_child.child_by_field_name("right")
763                    {
764                        extract_string_list_from_list_node(&right, source, result);
765                    }
766                }
767            }
768        } else if child.kind() == "expression_statement" {
769            // Fallback for older Python AST structures
770            let mut stmt_cursor = child.walk();
771            for stmt_child in child.children(&mut stmt_cursor) {
772                if stmt_child.kind() == "assignment"
773                    && let Some(left) = stmt_child.child_by_field_name("left")
774                {
775                    let target_text = source[left.start_byte()..left.end_byte()].trim();
776                    if target_text == "__all__"
777                        && let Some(right) = stmt_child.child_by_field_name("right")
778                    {
779                        extract_string_list_from_list_node(&right, source, result);
780                    }
781                }
782            }
783        }
784    }
785}
786
787/// Extract string literals from a Python list node.
788fn extract_string_list_from_list_node(
789    list_node: &tree_sitter::Node,
790    source: &str,
791    result: &mut Vec<String>,
792) {
793    let mut cursor = list_node.walk();
794    for child in list_node.named_children(&mut cursor) {
795        if child.kind() == "string" {
796            let raw = source[child.start_byte()..child.end_byte()].trim();
797            // Strip quotes: "name" -> name
798            let unquoted = raw.trim_matches('"').trim_matches('\'').to_string();
799            if !unquoted.is_empty() {
800                result.push(unquoted);
801            }
802        }
803    }
804}
805
#[cfg(test)]
mod tests {
    //! Pagination tests for the call chains produced by `analyze_focused`.

    use super::*;
    use crate::formatter::format_focused_paginated;
    use crate::pagination::{PaginationMode, decode_cursor, paginate_slice};
    use std::fs;
    use tempfile::TempDir;

    /// Build Rust source containing `fn target() {}` followed by `count`
    /// functions `caller_NN` that each call `target` once.
    fn target_with_callers(count: usize) -> String {
        let mut code = String::from("fn target() {}\n");
        for i in 0..count {
            code.push_str(&format!("fn caller_{:02}() {{ target(); }}\n", i));
        }
        code
    }

    #[test]
    fn test_symbol_focus_callers_pagination_first_page() {
        let temp_dir = TempDir::new().unwrap();
        fs::write(temp_dir.path().join("lib.rs"), target_with_callers(15)).unwrap();

        // Act
        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();

        // Paginate prod callers with page_size=5
        let paginated = paginate_slice(&output.prod_chains, 0, 5, PaginationMode::Callers)
            .expect("paginate failed");
        assert!(
            paginated.total >= 5,
            "should have enough callers to paginate"
        );
        assert_eq!(paginated.items.len(), 5);

        // The cursor for page 2 must resume right after the first page.
        let cursor_str = paginated
            .next_cursor
            .as_ref()
            .expect("should have next_cursor for page 1");
        let cursor_data = decode_cursor(cursor_str).expect("decode failed");
        assert_eq!(cursor_data.offset, 5, "page 2 should start after 5 items");
    }

    #[test]
    fn test_symbol_focus_callers_pagination_second_page() {
        let temp_dir = TempDir::new().unwrap();
        fs::write(temp_dir.path().join("lib.rs"), target_with_callers(12)).unwrap();

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let total_prod = output.prod_chains.len();
        // Fail loudly rather than silently skipping every check below.
        assert!(
            total_prod > 5,
            "expected more than one page of prod callers, got {}",
            total_prod
        );

        // Get page 1 cursor
        let p1 = paginate_slice(&output.prod_chains, 0, 5, PaginationMode::Callers)
            .expect("paginate failed");
        let cursor_str = p1
            .next_cursor
            .expect("should have next_cursor for page 1");
        let cursor_data = decode_cursor(&cursor_str).expect("decode failed");

        // Get page 2
        let p2 = paginate_slice(
            &output.prod_chains,
            cursor_data.offset,
            5,
            PaginationMode::Callers,
        )
        .expect("paginate failed");

        // Format paginated output
        let formatted = format_focused_paginated(
            &p2.items,
            total_prod,
            PaginationMode::Callers,
            "target",
            &output.prod_chains,
            &output.test_chains,
            &output.outgoing_chains,
            output.def_count,
            cursor_data.offset,
            Some(temp_dir.path()),
            true,
        );

        // Assert: header shows correct range for page 2
        let expected_start = cursor_data.offset + 1;
        assert!(
            formatted.contains(&format!("CALLERS ({}", expected_start)),
            "header should show page 2 range, got: {}",
            formatted
        );
    }

    #[test]
    fn test_symbol_focus_callees_pagination() {
        let temp_dir = TempDir::new().unwrap();

        // target calls many functions
        let mut code = String::from("fn target() {\n");
        for i in 0..10 {
            code.push_str(&format!("    callee_{:02}();\n", i));
        }
        code.push_str("}\n");
        for i in 0..10 {
            code.push_str(&format!("fn callee_{:02}() {{}}\n", i));
        }
        fs::write(temp_dir.path().join("lib.rs"), &code).unwrap();

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let total_callees = output.outgoing_chains.len();
        // Fail loudly rather than silently skipping every check below.
        assert!(
            total_callees > 3,
            "expected more than one page of callees, got {}",
            total_callees
        );

        let paginated = paginate_slice(&output.outgoing_chains, 0, 3, PaginationMode::Callees)
            .expect("paginate failed");

        let formatted = format_focused_paginated(
            &paginated.items,
            total_callees,
            PaginationMode::Callees,
            "target",
            &output.prod_chains,
            &output.test_chains,
            &output.outgoing_chains,
            output.def_count,
            0,
            Some(temp_dir.path()),
            true,
        );

        assert!(
            formatted.contains(&format!(
                "CALLEES (1-{} of {})",
                paginated.items.len(),
                total_callees
            )),
            "header should show callees range, got: {}",
            formatted
        );
    }

    #[test]
    fn test_symbol_focus_empty_prod_callers() {
        let temp_dir = TempDir::new().unwrap();

        // target is only called from test functions
        let code = r#"
fn target() {}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_something() { target(); }
}
"#;
        fs::write(temp_dir.path().join("lib.rs"), code).unwrap();

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();

        // prod_chains may be empty; pagination should handle it gracefully
        let paginated = paginate_slice(&output.prod_chains, 0, 100, PaginationMode::Callers)
            .expect("paginate failed");
        assert_eq!(paginated.items.len(), output.prod_chains.len());
        assert!(
            paginated.next_cursor.is_none(),
            "no next_cursor for empty or single-page prod_chains"
        );
    }
}
976}