1use crate::formatter::{
7 format_file_details, format_focused_internal, format_focused_summary_internal, format_structure,
8};
9use crate::graph::{CallGraph, InternalCallChain};
10use crate::lang::language_from_extension;
11use crate::parser::{ElementExtractor, SemanticExtractor};
12use crate::test_detection::is_test_file;
13use crate::traversal::{WalkEntry, walk_directory};
14use crate::types::{
15 AnalysisMode, FileInfo, ImplTraitInfo, ImportInfo, SemanticAnalysis, SymbolMatchMode,
16};
17use rayon::prelude::*;
18use schemars::JsonSchema;
19use serde::Serialize;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::time::Instant;
24use thiserror::Error;
25use tokio_util::sync::CancellationToken;
26use tracing::instrument;
27
/// Errors returned by the analysis entry points in this module.
///
/// Every variant except [`AnalyzeError::Cancelled`] wraps an error type from
/// another module of this crate and is convertible from it through `#[from]`,
/// which is what allows `?` to be used on those results.
#[derive(Debug, Error)]
pub enum AnalyzeError {
    /// An error raised by `crate::traversal`.
    #[error("Traversal error: {0}")]
    Traversal(#[from] crate::traversal::TraversalError),
    /// An error raised by `crate::parser`. This module also uses it to report
    /// file-read failures and unsupported extensions.
    #[error("Parser error: {0}")]
    Parser(#[from] crate::parser::ParserError),
    /// An error raised by `crate::graph`, e.g. a focus symbol that was not found.
    #[error("Graph error: {0}")]
    Graph(#[from] crate::graph::GraphError),
    /// An error raised by `crate::formatter`.
    #[error("Formatter error: {0}")]
    Formatter(#[from] crate::formatter::FormatterError),
    /// The cancellation token passed to the analysis was observed as cancelled.
    #[error("Analysis cancelled")]
    Cancelled,
}
41
// Result of a directory-level analysis.
//
// Plain `//` comments are used here on purpose: this type derives `JsonSchema`,
// and doc comments could be picked up as schema descriptions.
#[derive(Debug, Clone, Serialize, JsonSchema)]
pub struct AnalysisOutput {
    // Rendered text produced by `format_structure`.
    #[schemars(description = "Formatted text representation of the analysis")]
    pub formatted: String,
    // One record per file that could be read during the analysis.
    #[schemars(description = "List of files analyzed in the directory")]
    pub files: Vec<FileInfo>,
    // The walk entries the analysis was run on; kept out of the serialized
    // output and out of the schema.
    #[serde(skip)]
    #[schemars(skip)]
    pub entries: Vec<WalkEntry>,
    // Not serialized. `analyze_directory_with_progress` leaves this as `None`;
    // presumably a caller fills in per-directory counts later (TODO confirm).
    #[serde(skip)]
    #[schemars(skip)]
    pub subtree_counts: Option<Vec<(std::path::PathBuf, usize)>>,
    // Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(
        description = "Opaque cursor token for the next page of results (absent when no more results)"
    )]
    pub next_cursor: Option<String>,
}
63
// Result of a single-file analysis.
//
// Plain `//` comments are used here on purpose: this type derives `JsonSchema`,
// and doc comments could be picked up as schema descriptions.
#[derive(Debug, Clone, Serialize, JsonSchema)]
pub struct FileAnalysisOutput {
    // Rendered text produced by `format_file_details`.
    #[schemars(description = "Formatted text representation of the analysis")]
    pub formatted: String,
    // Extraction result for the file, as returned by `SemanticExtractor::extract`
    // and post-processed by `analyze_file`.
    #[schemars(description = "Semantic analysis data including functions, classes, and imports")]
    pub semantic: SemanticAnalysis,
    // Number of lines in the file (`str::lines().count()` of its contents).
    #[schemars(description = "Total line count of the analyzed file")]
    #[schemars(schema_with = "crate::schema_helpers::integer_schema")]
    pub line_count: usize,
    // Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(
        description = "Opaque cursor token for the next page of results (absent when no more results)"
    )]
    pub next_cursor: Option<String>,
}
80
81#[instrument(skip_all, fields(path = %root.display()))]
83#[allow(clippy::needless_pass_by_value)]
85pub fn analyze_directory_with_progress(
86 root: &Path,
87 entries: Vec<WalkEntry>,
88 progress: Arc<AtomicUsize>,
89 ct: CancellationToken,
90) -> Result<AnalysisOutput, AnalyzeError> {
91 if ct.is_cancelled() {
93 return Err(AnalyzeError::Cancelled);
94 }
95
96 let file_entries: Vec<&WalkEntry> = entries.iter().filter(|e| !e.is_dir).collect();
98
99 let start = Instant::now();
100 tracing::debug!(file_count = file_entries.len(), root = %root.display(), "analysis start");
101
102 let analysis_results: Vec<FileInfo> = file_entries
104 .par_iter()
105 .filter_map(|entry| {
106 if ct.is_cancelled() {
108 return None;
109 }
110
111 let path_str = entry.path.display().to_string();
112
113 let ext = entry.path.extension().and_then(|e| e.to_str());
115
116 let Ok(source) = std::fs::read_to_string(&entry.path) else {
118 progress.fetch_add(1, Ordering::Relaxed);
119 return None;
120 };
121
122 let line_count = source.lines().count();
124
125 let (language, function_count, class_count) = if let Some(ext_str) = ext {
127 if let Some(lang) = language_from_extension(ext_str) {
128 let lang_str = lang.to_string();
129 match ElementExtractor::extract_with_depth(&source, &lang_str) {
130 Ok((func_count, class_count)) => (lang_str, func_count, class_count),
131 Err(_) => (lang_str, 0, 0),
132 }
133 } else {
134 ("unknown".to_string(), 0, 0)
135 }
136 } else {
137 ("unknown".to_string(), 0, 0)
138 };
139
140 progress.fetch_add(1, Ordering::Relaxed);
141
142 let is_test = is_test_file(&entry.path);
143
144 Some(FileInfo {
145 path: path_str,
146 line_count,
147 function_count,
148 class_count,
149 language,
150 is_test,
151 })
152 })
153 .collect();
154
155 if ct.is_cancelled() {
157 return Err(AnalyzeError::Cancelled);
158 }
159
160 tracing::debug!(
161 file_count = file_entries.len(),
162 duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
163 "analysis complete"
164 );
165
166 let formatted = format_structure(&entries, &analysis_results, None);
168
169 Ok(AnalysisOutput {
170 formatted,
171 files: analysis_results,
172 entries,
173 next_cursor: None,
174 subtree_counts: None,
175 })
176}
177
178#[instrument(skip_all, fields(path = %root.display()))]
180pub fn analyze_directory(
181 root: &Path,
182 max_depth: Option<u32>,
183) -> Result<AnalysisOutput, AnalyzeError> {
184 let entries = walk_directory(root, max_depth)?;
185 let counter = Arc::new(AtomicUsize::new(0));
186 let ct = CancellationToken::new();
187 analyze_directory_with_progress(root, entries, counter, ct)
188}
189
190#[must_use]
192pub fn determine_mode(path: &str, focus: Option<&str>) -> AnalysisMode {
193 if focus.is_some() {
194 return AnalysisMode::SymbolFocus;
195 }
196
197 let path_obj = Path::new(path);
198 if path_obj.is_dir() {
199 AnalysisMode::Overview
200 } else {
201 AnalysisMode::FileDetails
202 }
203}
204
205#[instrument(skip_all, fields(path))]
207pub fn analyze_file(
208 path: &str,
209 ast_recursion_limit: Option<usize>,
210) -> Result<FileAnalysisOutput, AnalyzeError> {
211 let start = Instant::now();
212 let source = std::fs::read_to_string(path)
213 .map_err(|e| AnalyzeError::Parser(crate::parser::ParserError::ParseError(e.to_string())))?;
214
215 let line_count = source.lines().count();
216
217 let ext = Path::new(path)
219 .extension()
220 .and_then(|e| e.to_str())
221 .and_then(language_from_extension)
222 .map_or_else(|| "unknown".to_string(), std::string::ToString::to_string);
223
224 let mut semantic = SemanticExtractor::extract(&source, &ext, ast_recursion_limit)?;
226
227 for r in &mut semantic.references {
229 r.location = path.to_string();
230 }
231
232 if ext == "python" {
234 resolve_wildcard_imports(Path::new(path), &mut semantic.imports);
235 }
236
237 let is_test = is_test_file(Path::new(path));
239
240 let parent_dir = Path::new(path).parent();
242
243 let formatted = format_file_details(path, &semantic, line_count, is_test, parent_dir);
245
246 tracing::debug!(path = %path, language = %ext, functions = semantic.functions.len(), classes = semantic.classes.len(), imports = semantic.imports.len(), duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX), "file analysis complete");
247
248 Ok(FileAnalysisOutput {
249 formatted,
250 semantic,
251 line_count,
252 next_cursor: None,
253 })
254}
255
// Result of a symbol-focused (call graph) analysis.
//
// Only `formatted` and `next_cursor` are serialized; the remaining fields are
// kept for in-process consumers. Plain `//` comments are used on purpose: this
// type derives `JsonSchema`, and doc comments could be picked up as schema
// descriptions.
#[derive(Debug, Serialize, JsonSchema)]
pub struct FocusedAnalysisOutput {
    // Rendered call-graph text, prefixed with a `FILTER:` line when
    // `impl_only` was requested.
    #[schemars(description = "Formatted text representation of the call graph analysis")]
    pub formatted: String,
    // Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(
        description = "Opaque cursor token for the next page of results (absent when no more results)"
    )]
    pub next_cursor: Option<String>,
    // Incoming chains that `compute_chains` classified as production:
    // the chain is empty, or its first hop is neither in a test file nor
    // named `test_*`.
    #[serde(skip)]
    #[schemars(skip)]
    pub(crate) prod_chains: Vec<InternalCallChain>,
    // Incoming chains not classified as production (see `prod_chains`).
    #[serde(skip)]
    #[schemars(skip)]
    pub(crate) test_chains: Vec<InternalCallChain>,
    // Chains returned by `CallGraph::find_outgoing_chains` for the focus symbol.
    #[serde(skip)]
    #[schemars(skip)]
    pub(crate) outgoing_chains: Vec<InternalCallChain>,
    // Number of definitions the graph recorded for the resolved focus symbol.
    #[serde(skip)]
    #[schemars(skip)]
    pub def_count: usize,
    // Distinct caller names of the focus symbol before any `impl_only` filtering.
    #[serde(skip)]
    #[schemars(skip)]
    pub unfiltered_caller_count: usize,
    // Distinct caller names remaining after `impl_only` filtering; equal to
    // `unfiltered_caller_count` when the filter is off.
    #[serde(skip)]
    #[schemars(skip)]
    pub impl_trait_caller_count: usize,
}
292
/// Caller-facing parameters for a symbol-focused analysis.
#[derive(Clone)]
pub struct AnalyzeSymbolParams {
    /// Name of the symbol to focus on.
    pub focus: String,
    /// How `focus` is matched against the symbols known to the call graph.
    pub match_mode: SymbolMatchMode,
    /// Depth passed to the incoming/outgoing chain searches.
    pub follow_depth: u32,
    /// Depth limit passed to `walk_directory`; `None` is passed through as-is.
    pub max_depth: Option<u32>,
    /// Forwarded unchanged to `SemanticExtractor::extract`.
    pub ast_recursion_limit: Option<usize>,
    /// Selects the summary formatter instead of the full one.
    pub use_summary: bool,
    /// When `Some(true)`, only caller edges flagged `is_impl_trait` are kept.
    /// `None` behaves like `Some(false)`.
    pub impl_only: Option<bool>,
}
305
/// Internal counterpart of [`AnalyzeSymbolParams`] without `max_depth`, which
/// is handed to the internal entry point as a separate argument.
#[derive(Clone)]
struct FocusedAnalysisParams {
    /// Name of the symbol to focus on.
    focus: String,
    /// How `focus` is matched against the symbols known to the call graph.
    match_mode: SymbolMatchMode,
    /// Depth passed to the incoming/outgoing chain searches.
    follow_depth: u32,
    /// Forwarded unchanged to `SemanticExtractor::extract`.
    ast_recursion_limit: Option<usize>,
    /// Selects the summary formatter instead of the full one.
    use_summary: bool,
    /// When `Some(true)`, only caller edges flagged `is_impl_trait` are kept.
    impl_only: Option<bool>,
}
316
/// Output of `collect_file_analysis`: the per-file semantic analyses paired
/// with their paths, plus every `ImplTraitInfo` gathered from those analyses.
type AnalysisResults = (Vec<(PathBuf, SemanticAnalysis)>, Vec<ImplTraitInfo>);
319
320fn collect_file_analysis(
322 entries: &[WalkEntry],
323 progress: &Arc<AtomicUsize>,
324 ct: &CancellationToken,
325 ast_recursion_limit: Option<usize>,
326) -> Result<AnalysisResults, AnalyzeError> {
327 if ct.is_cancelled() {
329 return Err(AnalyzeError::Cancelled);
330 }
331
332 let file_entries: Vec<&WalkEntry> = entries.iter().filter(|e| !e.is_dir).collect();
335
336 let analysis_results: Vec<(PathBuf, SemanticAnalysis)> = file_entries
337 .par_iter()
338 .filter_map(|entry| {
339 if ct.is_cancelled() {
341 return None;
342 }
343
344 let ext = entry.path.extension().and_then(|e| e.to_str());
345
346 let Ok(source) = std::fs::read_to_string(&entry.path) else {
348 progress.fetch_add(1, Ordering::Relaxed);
349 return None;
350 };
351
352 let language = if let Some(ext_str) = ext {
354 language_from_extension(ext_str)
355 .map_or_else(|| "unknown".to_string(), std::string::ToString::to_string)
356 } else {
357 "unknown".to_string()
358 };
359
360 if let Ok(mut semantic) =
361 SemanticExtractor::extract(&source, &language, ast_recursion_limit)
362 {
363 for r in &mut semantic.references {
365 r.location = entry.path.display().to_string();
366 }
367 for trait_info in &mut semantic.impl_traits {
369 trait_info.path.clone_from(&entry.path);
370 }
371 progress.fetch_add(1, Ordering::Relaxed);
372 Some((entry.path.clone(), semantic))
373 } else {
374 progress.fetch_add(1, Ordering::Relaxed);
375 None
376 }
377 })
378 .collect();
379
380 if ct.is_cancelled() {
382 return Err(AnalyzeError::Cancelled);
383 }
384
385 let all_impl_traits: Vec<ImplTraitInfo> = analysis_results
387 .iter()
388 .flat_map(|(_, sem)| sem.impl_traits.iter().cloned())
389 .collect();
390
391 Ok((analysis_results, all_impl_traits))
392}
393
394fn build_call_graph(
396 analysis_results: Vec<(PathBuf, SemanticAnalysis)>,
397 all_impl_traits: &[ImplTraitInfo],
398) -> Result<CallGraph, AnalyzeError> {
399 CallGraph::build_from_results(
402 analysis_results,
403 all_impl_traits,
404 false, )
406 .map_err(std::convert::Into::into)
407}
408
409fn resolve_symbol(
414 graph: &mut CallGraph,
415 params: &FocusedAnalysisParams,
416) -> Result<(String, usize, usize), AnalyzeError> {
417 let resolved_focus = if params.match_mode == SymbolMatchMode::Exact {
419 let exists = graph.definitions.contains_key(¶ms.focus)
420 || graph.callers.contains_key(¶ms.focus)
421 || graph.callees.contains_key(¶ms.focus);
422 if exists {
423 params.focus.clone()
424 } else {
425 return Err(crate::graph::GraphError::SymbolNotFound {
426 symbol: params.focus.clone(),
427 hint: "Try match_mode=insensitive for a case-insensitive search.".to_string(),
428 }
429 .into());
430 }
431 } else {
432 graph.resolve_symbol_indexed(¶ms.focus, ¶ms.match_mode)?
433 };
434
435 let unfiltered_caller_count = graph.callers.get(&resolved_focus).map_or(0, |edges| {
437 edges
438 .iter()
439 .map(|e| &e.neighbor_name)
440 .collect::<std::collections::HashSet<_>>()
441 .len()
442 });
443
444 let impl_trait_caller_count = if params.impl_only.unwrap_or(false) {
448 for edges in graph.callers.values_mut() {
449 edges.retain(|e| e.is_impl_trait);
450 }
451 graph.callers.get(&resolved_focus).map_or(0, |edges| {
452 edges
453 .iter()
454 .map(|e| &e.neighbor_name)
455 .collect::<std::collections::HashSet<_>>()
456 .len()
457 })
458 } else {
459 unfiltered_caller_count
460 };
461
462 Ok((
463 resolved_focus,
464 unfiltered_caller_count,
465 impl_trait_caller_count,
466 ))
467}
468
/// Output of `compute_chains`, in order: formatted text, production incoming
/// chains, test incoming chains, outgoing chains, and the definition count of
/// the focus symbol.
type ChainComputeResult = (
    String,
    Vec<InternalCallChain>,
    Vec<InternalCallChain>,
    Vec<InternalCallChain>,
    usize,
);
477
478fn compute_chains(
480 graph: &CallGraph,
481 resolved_focus: &str,
482 root: &Path,
483 params: &FocusedAnalysisParams,
484 unfiltered_caller_count: usize,
485 impl_trait_caller_count: usize,
486) -> Result<ChainComputeResult, AnalyzeError> {
487 let def_count = graph.definitions.get(resolved_focus).map_or(0, Vec::len);
489 let incoming_chains = graph.find_incoming_chains(resolved_focus, params.follow_depth)?;
490 let outgoing_chains = graph.find_outgoing_chains(resolved_focus, params.follow_depth)?;
491
492 let (prod_chains, test_chains): (Vec<_>, Vec<_>) =
493 incoming_chains.iter().cloned().partition(|chain| {
494 chain
495 .chain
496 .first()
497 .is_none_or(|(name, path, _)| !is_test_file(path) && !name.starts_with("test_"))
498 });
499
500 let mut formatted = if params.use_summary {
502 format_focused_summary_internal(
503 graph,
504 resolved_focus,
505 params.follow_depth,
506 Some(root),
507 Some(&incoming_chains),
508 Some(&outgoing_chains),
509 )?
510 } else {
511 format_focused_internal(
512 graph,
513 resolved_focus,
514 params.follow_depth,
515 Some(root),
516 Some(&incoming_chains),
517 Some(&outgoing_chains),
518 )?
519 };
520
521 if params.impl_only.unwrap_or(false) {
523 let filter_header = format!(
524 "FILTER: impl_only=true ({impl_trait_caller_count} of {unfiltered_caller_count} callers shown)\n",
525 );
526 formatted = format!("{filter_header}{formatted}");
527 }
528
529 Ok((
530 formatted,
531 prod_chains,
532 test_chains,
533 outgoing_chains,
534 def_count,
535 ))
536}
537
538#[allow(clippy::needless_pass_by_value)]
541pub fn analyze_focused_with_progress(
542 root: &Path,
543 params: &AnalyzeSymbolParams,
544 progress: Arc<AtomicUsize>,
545 ct: CancellationToken,
546) -> Result<FocusedAnalysisOutput, AnalyzeError> {
547 let entries = walk_directory(root, params.max_depth)?;
548 let internal_params = FocusedAnalysisParams {
549 focus: params.focus.clone(),
550 match_mode: params.match_mode.clone(),
551 follow_depth: params.follow_depth,
552 ast_recursion_limit: params.ast_recursion_limit,
553 use_summary: params.use_summary,
554 impl_only: params.impl_only,
555 };
556 analyze_focused_with_progress_with_entries_internal(
557 root,
558 params.max_depth,
559 &progress,
560 &ct,
561 &internal_params,
562 &entries,
563 )
564}
565
566#[instrument(skip_all, fields(path = %root.display(), symbol = %params.focus))]
568fn analyze_focused_with_progress_with_entries_internal(
569 root: &Path,
570 _max_depth: Option<u32>,
571 progress: &Arc<AtomicUsize>,
572 ct: &CancellationToken,
573 params: &FocusedAnalysisParams,
574 entries: &[WalkEntry],
575) -> Result<FocusedAnalysisOutput, AnalyzeError> {
576 if ct.is_cancelled() {
578 return Err(AnalyzeError::Cancelled);
579 }
580
581 if root.is_file() {
583 let formatted =
584 "Single-file focus not supported. Please provide a directory path for cross-file call graph analysis.\n"
585 .to_string();
586 return Ok(FocusedAnalysisOutput {
587 formatted,
588 next_cursor: None,
589 prod_chains: vec![],
590 test_chains: vec![],
591 outgoing_chains: vec![],
592 def_count: 0,
593 unfiltered_caller_count: 0,
594 impl_trait_caller_count: 0,
595 });
596 }
597
598 let (analysis_results, all_impl_traits) =
600 collect_file_analysis(entries, progress, ct, params.ast_recursion_limit)?;
601
602 if ct.is_cancelled() {
604 return Err(AnalyzeError::Cancelled);
605 }
606
607 let mut graph = build_call_graph(analysis_results, &all_impl_traits)?;
609
610 if ct.is_cancelled() {
612 return Err(AnalyzeError::Cancelled);
613 }
614
615 let (resolved_focus, unfiltered_caller_count, impl_trait_caller_count) =
617 resolve_symbol(&mut graph, params)?;
618
619 if ct.is_cancelled() {
621 return Err(AnalyzeError::Cancelled);
622 }
623
624 let (formatted, prod_chains, test_chains, outgoing_chains, def_count) = compute_chains(
626 &graph,
627 &resolved_focus,
628 root,
629 params,
630 unfiltered_caller_count,
631 impl_trait_caller_count,
632 )?;
633
634 Ok(FocusedAnalysisOutput {
635 formatted,
636 next_cursor: None,
637 prod_chains,
638 test_chains,
639 outgoing_chains,
640 def_count,
641 unfiltered_caller_count,
642 impl_trait_caller_count,
643 })
644}
645
646pub(crate) fn analyze_focused_with_progress_with_entries(
648 root: &Path,
649 params: &AnalyzeSymbolParams,
650 progress: &Arc<AtomicUsize>,
651 ct: &CancellationToken,
652 entries: &[WalkEntry],
653) -> Result<FocusedAnalysisOutput, AnalyzeError> {
654 let internal_params = FocusedAnalysisParams {
655 focus: params.focus.clone(),
656 match_mode: params.match_mode.clone(),
657 follow_depth: params.follow_depth,
658 ast_recursion_limit: params.ast_recursion_limit,
659 use_summary: params.use_summary,
660 impl_only: params.impl_only,
661 };
662 analyze_focused_with_progress_with_entries_internal(
663 root,
664 params.max_depth,
665 progress,
666 ct,
667 &internal_params,
668 entries,
669 )
670}
671
672#[instrument(skip_all, fields(path = %root.display(), symbol = %focus))]
673pub fn analyze_focused(
674 root: &Path,
675 focus: &str,
676 follow_depth: u32,
677 max_depth: Option<u32>,
678 ast_recursion_limit: Option<usize>,
679) -> Result<FocusedAnalysisOutput, AnalyzeError> {
680 let entries = walk_directory(root, max_depth)?;
681 let counter = Arc::new(AtomicUsize::new(0));
682 let ct = CancellationToken::new();
683 let params = AnalyzeSymbolParams {
684 focus: focus.to_string(),
685 match_mode: SymbolMatchMode::Exact,
686 follow_depth,
687 max_depth,
688 ast_recursion_limit,
689 use_summary: false,
690 impl_only: None,
691 };
692 analyze_focused_with_progress_with_entries(root, ¶ms, &counter, &ct, &entries)
693}
694
695#[instrument(skip_all, fields(path))]
698pub fn analyze_module_file(path: &str) -> Result<crate::types::ModuleInfo, AnalyzeError> {
699 let source = std::fs::read_to_string(path)
700 .map_err(|e| AnalyzeError::Parser(crate::parser::ParserError::ParseError(e.to_string())))?;
701
702 let file_path = Path::new(path);
703 let name = file_path
704 .file_name()
705 .and_then(|s| s.to_str())
706 .unwrap_or("unknown")
707 .to_string();
708
709 let line_count = source.lines().count();
710
711 let language = file_path
712 .extension()
713 .and_then(|e| e.to_str())
714 .and_then(language_from_extension)
715 .ok_or_else(|| {
716 AnalyzeError::Parser(crate::parser::ParserError::ParseError(
717 "unsupported or missing file extension".to_string(),
718 ))
719 })?;
720
721 let semantic = SemanticExtractor::extract(&source, language, None)?;
722
723 let functions = semantic
724 .functions
725 .into_iter()
726 .map(|f| crate::types::ModuleFunctionInfo {
727 name: f.name,
728 line: f.line,
729 })
730 .collect();
731
732 let imports = semantic
733 .imports
734 .into_iter()
735 .map(|i| crate::types::ModuleImportInfo {
736 module: i.module,
737 items: i.items,
738 })
739 .collect();
740
741 Ok(crate::types::ModuleInfo {
742 name,
743 line_count,
744 language: language.to_string(),
745 functions,
746 imports,
747 })
748}
749
750fn resolve_wildcard_imports(file_path: &Path, imports: &mut [ImportInfo]) {
760 use std::collections::HashMap;
761
762 let mut resolved_cache: HashMap<PathBuf, Vec<String>> = HashMap::new();
763 let Ok(file_path_canonical) = file_path.canonicalize() else {
764 tracing::debug!(file = ?file_path, "unable to canonicalize current file path");
765 return;
766 };
767
768 for import in imports.iter_mut() {
769 if import.items != ["*"] {
770 continue;
771 }
772 resolve_single_wildcard(import, file_path, &file_path_canonical, &mut resolved_cache);
773 }
774}
775
776fn resolve_single_wildcard(
778 import: &mut ImportInfo,
779 file_path: &Path,
780 file_path_canonical: &Path,
781 resolved_cache: &mut std::collections::HashMap<PathBuf, Vec<String>>,
782) {
783 let module = import.module.clone();
784 let dot_count = module.chars().take_while(|c| *c == '.').count();
785 if dot_count == 0 {
786 return;
787 }
788 let module_path = module.trim_start_matches('.');
789
790 let Some(target_to_read) = locate_target_file(file_path, dot_count, module_path, &module)
791 else {
792 return;
793 };
794
795 let Ok(canonical) = target_to_read.canonicalize() else {
796 tracing::debug!(target = ?target_to_read, import = %module, "unable to canonicalize path");
797 return;
798 };
799
800 if canonical == file_path_canonical {
801 tracing::debug!(target = ?canonical, import = %module, "cannot import from self");
802 return;
803 }
804
805 if let Some(cached) = resolved_cache.get(&canonical) {
806 tracing::debug!(import = %module, symbols_count = cached.len(), "using cached symbols");
807 import.items.clone_from(cached);
808 return;
809 }
810
811 if let Some(symbols) = parse_target_symbols(&target_to_read, &module) {
812 tracing::debug!(import = %module, resolved_count = symbols.len(), "wildcard import resolved");
813 import.items.clone_from(&symbols);
814 resolved_cache.insert(canonical, symbols);
815 }
816}
817
818fn locate_target_file(
820 file_path: &Path,
821 dot_count: usize,
822 module_path: &str,
823 module: &str,
824) -> Option<PathBuf> {
825 let mut target_dir = file_path.parent()?.to_path_buf();
826
827 for _ in 1..dot_count {
828 if !target_dir.pop() {
829 tracing::debug!(import = %module, "unable to climb {} levels", dot_count.saturating_sub(1));
830 return None;
831 }
832 }
833
834 let target_file = if module_path.is_empty() {
835 target_dir.join("__init__.py")
836 } else {
837 let rel_path = module_path.replace('.', "/");
838 target_dir.join(format!("{rel_path}.py"))
839 };
840
841 if target_file.exists() {
842 Some(target_file)
843 } else if target_file.with_extension("").is_dir() {
844 let init = target_file.with_extension("").join("__init__.py");
845 if init.exists() { Some(init) } else { None }
846 } else {
847 tracing::debug!(target = ?target_file, import = %module, "target file not found");
848 None
849 }
850}
851
852fn parse_target_symbols(target_path: &Path, module: &str) -> Option<Vec<String>> {
854 use tree_sitter::Parser;
855
856 let source = match std::fs::read_to_string(target_path) {
857 Ok(s) => s,
858 Err(e) => {
859 tracing::debug!(target = ?target_path, import = %module, error = %e, "unable to read target file");
860 return None;
861 }
862 };
863
864 let lang_info = crate::languages::get_language_info("python")?;
866 let mut parser = Parser::new();
867 if parser.set_language(&lang_info.language).is_err() {
868 return None;
869 }
870 let tree = parser.parse(&source, None)?;
871
872 let mut symbols = Vec::new();
874 extract_all_from_tree(&tree, &source, &mut symbols);
875 if !symbols.is_empty() {
876 tracing::debug!(import = %module, symbols = ?symbols, "using __all__ symbols");
877 return Some(symbols);
878 }
879
880 let root = tree.root_node();
882 let mut cursor = root.walk();
883 for child in root.children(&mut cursor) {
884 if matches!(child.kind(), "function_definition" | "class_definition")
885 && let Some(name_node) = child.child_by_field_name("name")
886 {
887 let name = source[name_node.start_byte()..name_node.end_byte()].to_string();
888 if !name.starts_with('_') {
889 symbols.push(name);
890 }
891 }
892 }
893 tracing::debug!(import = %module, fallback_symbols = ?symbols, "using fallback function/class names");
894 Some(symbols)
895}
896
897fn extract_all_from_tree(tree: &tree_sitter::Tree, source: &str, result: &mut Vec<String>) {
899 let root = tree.root_node();
900 let mut cursor = root.walk();
901 for child in root.children(&mut cursor) {
902 if child.kind() == "simple_statement" {
903 let mut simple_cursor = child.walk();
905 for simple_child in child.children(&mut simple_cursor) {
906 if simple_child.kind() == "assignment"
907 && let Some(left) = simple_child.child_by_field_name("left")
908 {
909 let target_text = source[left.start_byte()..left.end_byte()].trim();
910 if target_text == "__all__"
911 && let Some(right) = simple_child.child_by_field_name("right")
912 {
913 extract_string_list_from_list_node(&right, source, result);
914 }
915 }
916 }
917 } else if child.kind() == "expression_statement" {
918 let mut stmt_cursor = child.walk();
920 for stmt_child in child.children(&mut stmt_cursor) {
921 if stmt_child.kind() == "assignment"
922 && let Some(left) = stmt_child.child_by_field_name("left")
923 {
924 let target_text = source[left.start_byte()..left.end_byte()].trim();
925 if target_text == "__all__"
926 && let Some(right) = stmt_child.child_by_field_name("right")
927 {
928 extract_string_list_from_list_node(&right, source, result);
929 }
930 }
931 }
932 }
933 }
934}
935
936fn extract_string_list_from_list_node(
938 list_node: &tree_sitter::Node,
939 source: &str,
940 result: &mut Vec<String>,
941) {
942 let mut cursor = list_node.walk();
943 for child in list_node.named_children(&mut cursor) {
944 if child.kind() == "string" {
945 let raw = source[child.start_byte()..child.end_byte()].trim();
946 let unquoted = raw.trim_matches('"').trim_matches('\'').to_string();
948 if !unquoted.is_empty() {
949 result.push(unquoted);
950 }
951 }
952 }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formatter::format_focused_paginated;
    use crate::pagination::{PaginationMode, decode_cursor, paginate_slice};
    use std::fs;
    use tempfile::TempDir;

    /// Creates a temporary directory containing a single `lib.rs` with `code`.
    /// The returned guard must be kept alive for as long as the files are needed.
    fn project_with_lib(code: &str) -> TempDir {
        let temp_dir = TempDir::new().unwrap();
        fs::write(temp_dir.path().join("lib.rs"), code).unwrap();
        temp_dir
    }

    /// Source text defining `target` followed by `count` functions that call it.
    fn target_with_callers(count: usize) -> String {
        let callers: String = (0..count)
            .map(|i| format!("fn caller_{i:02}() {{ target(); }}\n"))
            .collect();
        format!("fn target() {{}}\n{callers}")
    }

    /// The first page of callers is full and advertises a following page.
    #[test]
    fn test_symbol_focus_callers_pagination_first_page() {
        let temp_dir = project_with_lib(&target_with_callers(15));

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();

        let paginated = paginate_slice(&output.prod_chains, 0, 5, PaginationMode::Callers)
            .expect("paginate failed");
        assert!(
            paginated.total >= 5,
            "should have enough callers to paginate"
        );
        assert!(
            paginated.next_cursor.is_some(),
            "should have next_cursor for page 1"
        );

        assert_eq!(paginated.items.len(), 5);
    }

    /// The header of the second page starts at the offset encoded in the cursor.
    #[test]
    fn test_symbol_focus_callers_pagination_second_page() {
        let temp_dir = project_with_lib(&target_with_callers(12));

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let total_prod = output.prod_chains.len();

        // NOTE(review): when this guard is false the test passes without
        // checking anything — confirm whether that is intended.
        if total_prod > 5 {
            let p1 = paginate_slice(&output.prod_chains, 0, 5, PaginationMode::Callers)
                .expect("paginate failed");
            assert!(p1.next_cursor.is_some());

            let cursor_str = p1.next_cursor.unwrap();
            let cursor_data = decode_cursor(&cursor_str).expect("decode failed");

            let p2 = paginate_slice(
                &output.prod_chains,
                cursor_data.offset,
                5,
                PaginationMode::Callers,
            )
            .expect("paginate failed");

            let formatted = format_focused_paginated(
                &p2.items,
                total_prod,
                PaginationMode::Callers,
                "target",
                &output.prod_chains,
                &output.test_chains,
                &output.outgoing_chains,
                output.def_count,
                cursor_data.offset,
                Some(temp_dir.path()),
                true,
            );

            let expected_start = cursor_data.offset + 1;
            assert!(
                formatted.contains(&format!("CALLERS ({expected_start}")),
                "header should show page 2 range, got: {formatted}"
            );
        }
    }

    /// The callees header reports the range of the first page and the total.
    #[test]
    fn test_symbol_focus_callees_pagination() {
        let calls: String = (0..10).map(|i| format!("    callee_{i:02}();\n")).collect();
        let definitions: String = (0..10)
            .map(|i| format!("fn callee_{i:02}() {{}}\n"))
            .collect();
        let code = format!("fn target() {{\n{calls}}}\n{definitions}");
        let temp_dir = project_with_lib(&code);

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let total_callees = output.outgoing_chains.len();

        // NOTE(review): when this guard is false the test passes without
        // checking anything — confirm whether that is intended.
        if total_callees > 3 {
            let paginated = paginate_slice(&output.outgoing_chains, 0, 3, PaginationMode::Callees)
                .expect("paginate failed");

            let formatted = format_focused_paginated(
                &paginated.items,
                total_callees,
                PaginationMode::Callees,
                "target",
                &output.prod_chains,
                &output.test_chains,
                &output.outgoing_chains,
                output.def_count,
                0,
                Some(temp_dir.path()),
                true,
            );

            assert!(
                formatted.contains(&format!(
                    "CALLEES (1-{} of {})",
                    paginated.items.len(),
                    total_callees
                )),
                "header should show callees range, got: {formatted}"
            );
        }
    }

    /// Paginating a caller list that fits on one page yields no cursor.
    #[test]
    fn test_symbol_focus_empty_prod_callers() {
        let code = r#"
fn target() {}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_something() { target(); }
}
"#;
        let temp_dir = project_with_lib(code);

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();

        let paginated = paginate_slice(&output.prod_chains, 0, 100, PaginationMode::Callers)
            .expect("paginate failed");
        assert_eq!(paginated.items.len(), output.prod_chains.len());
        assert!(
            paginated.next_cursor.is_none(),
            "no next_cursor for empty or single-page prod_chains"
        );
    }

    /// With `impl_only`, the FILTER header reports the filtered and unfiltered counts.
    #[test]
    fn test_impl_only_filter_header_correct_counts() {
        let code = r#"
trait MyTrait {
    fn focus_symbol();
}

struct SomeType;

impl MyTrait for SomeType {
    fn focus_symbol() {}
}

fn impl_caller() {
    SomeType::focus_symbol();
}

fn regular_caller() {
    SomeType::focus_symbol();
}
"#;
        let temp_dir = project_with_lib(code);

        let params = AnalyzeSymbolParams {
            focus: "focus_symbol".to_string(),
            match_mode: SymbolMatchMode::Insensitive,
            follow_depth: 1,
            max_depth: None,
            ast_recursion_limit: None,
            use_summary: false,
            impl_only: Some(true),
        };
        let output = analyze_focused_with_progress(
            temp_dir.path(),
            &params,
            Arc::new(AtomicUsize::new(0)),
            CancellationToken::new(),
        )
        .unwrap();

        assert!(
            output.formatted.contains("FILTER: impl_only=true"),
            "formatted output should contain FILTER header for impl_only=true, got: {}",
            output.formatted
        );

        assert!(
            output.impl_trait_caller_count < output.unfiltered_caller_count,
            "impl_trait_caller_count ({}) should be less than unfiltered_caller_count ({})",
            output.impl_trait_caller_count,
            output.unfiltered_caller_count
        );

        let filter_line = output
            .formatted
            .lines()
            .find(|line| line.contains("FILTER: impl_only=true"))
            .expect("should find FILTER line");
        assert!(
            filter_line.contains(&format!(
                "({} of {} callers shown)",
                output.impl_trait_caller_count, output.unfiltered_caller_count
            )),
            "FILTER line should show correct N of M counts, got: {filter_line}"
        );
    }

    /// The caller count printed on the FOCUS line equals the number of distinct
    /// first callers in the production chains.
    #[test]
    fn test_callers_count_matches_formatted_output() {
        let code = r#"
fn target() {}
fn caller_a() { target(); }
fn caller_b() { target(); }
fn caller_c() { target(); }
"#;
        let temp_dir = project_with_lib(code);

        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();

        // Parse "<N> callers" out of the comma-separated FOCUS line.
        let callers_count_from_output = output
            .formatted
            .lines()
            .find(|line| line.contains("FOCUS:"))
            .and_then(|line| {
                line.split(',')
                    .find(|part| part.contains("callers"))
                    .and_then(|part| {
                        part.split_whitespace()
                            .next()
                            .and_then(|s| s.parse::<usize>().ok())
                    })
            })
            .expect("should find CALLERS count in formatted output");

        let expected_callers_count = output
            .prod_chains
            .iter()
            .filter_map(|chain| chain.chain.first().map(|(name, _, _)| name))
            .collect::<std::collections::HashSet<_>>()
            .len();

        assert_eq!(
            callers_count_from_output, expected_callers_count,
            "CALLERS count in formatted output should match unique-first-caller count in prod_chains"
        );
    }
}