use crate::formatter::{
format_file_details, format_focused_internal, format_focused_summary_internal, format_structure,
};
use crate::graph::{CallGraph, InternalCallChain};
use crate::lang::{language_for_extension, supported_languages};
use crate::parser::{ElementExtractor, SemanticExtractor};
use crate::test_detection::is_test_file;
use crate::traversal::{WalkEntry, walk_directory};
use crate::types::{
AnalysisMode, FileInfo, ImplTraitInfo, ImportInfo, SemanticAnalysis, SymbolMatchMode,
};
use rayon::prelude::*;
#[cfg(feature = "schemars")]
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
use thiserror::Error;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
/// Errors produced by the analysis entry points in this module.
///
/// Each pipeline stage's error type (traversal, parsing, graph
/// construction, formatting) is wrapped via `#[from]` so `?` converts
/// them automatically.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AnalyzeError {
    /// Directory walking failed (see `crate::traversal`).
    #[error("Traversal error: {0}")]
    Traversal(#[from] crate::traversal::TraversalError),
    /// Source parsing / extraction failed (see `crate::parser`).
    #[error("Parser error: {0}")]
    Parser(#[from] crate::parser::ParserError),
    /// Call-graph construction or symbol resolution failed.
    #[error("Graph error: {0}")]
    Graph(#[from] crate::graph::GraphError),
    /// Rendering the formatted output failed.
    #[error("Formatter error: {0}")]
    Formatter(#[from] crate::formatter::FormatterError),
    /// The operation was aborted via a `CancellationToken`.
    #[error("Analysis cancelled")]
    Cancelled,
    /// The requested language matched neither a known extension nor a
    /// supported language name.
    #[error("unsupported language: {0}")]
    UnsupportedLanguage(String),
}
/// Result of a directory-level analysis.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
#[non_exhaustive]
pub struct AnalysisOutput {
    /// Human-readable rendering of the directory structure and stats.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "Formatted text representation of the analysis")
    )]
    pub formatted: String,
    /// Per-file statistics for every file that could be analyzed.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "List of files analyzed in the directory")
    )]
    pub files: Vec<FileInfo>,
    /// Raw traversal entries (directories included); in-process only,
    /// never serialized.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub entries: Vec<WalkEntry>,
    /// Optional per-subtree entry counts; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub subtree_counts: Option<Vec<(std::path::PathBuf, usize)>>,
    /// Pagination cursor for the next page, omitted when exhausted.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[cfg_attr(
        feature = "schemars",
        schemars(
            description = "Opaque cursor token for the next page of results (absent when no more results)"
        )
    )]
    pub next_cursor: Option<String>,
}
/// Result of analyzing a single file (on disk or in memory).
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
#[non_exhaustive]
pub struct FileAnalysisOutput {
    /// Human-readable rendering of the file's details.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "Formatted text representation of the analysis")
    )]
    pub formatted: String,
    /// Structured extraction results (functions, classes, imports, ...).
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "Semantic analysis data including functions, classes, and imports")
    )]
    pub semantic: SemanticAnalysis,
    /// Total number of lines in the analyzed source.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "Total line count of the analyzed file")
    )]
    #[cfg_attr(
        feature = "schemars",
        schemars(schema_with = "crate::schema_helpers::integer_schema")
    )]
    pub line_count: usize,
    /// Pagination cursor for the next page, omitted when exhausted.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[cfg_attr(
        feature = "schemars",
        schemars(
            description = "Opaque cursor token for the next page of results (absent when no more results)"
        )
    )]
    pub next_cursor: Option<String>,
}
impl FileAnalysisOutput {
    /// Assembles a [`FileAnalysisOutput`] from its four components.
    #[must_use]
    pub fn new(
        formatted: String,
        semantic: SemanticAnalysis,
        line_count: usize,
        next_cursor: Option<String>,
    ) -> Self {
        // Plain field-shorthand construction; no validation is needed.
        Self {
            next_cursor,
            line_count,
            semantic,
            formatted,
        }
    }
}
/// Analyzes every file in `entries` in parallel, producing per-file
/// line/function/class counts plus a formatted directory overview.
///
/// `progress` is incremented once per file attempted — including files
/// that fail to read — so callers can show completion out of the total
/// file count. `ct` aborts the run.
///
/// # Errors
/// Returns [`AnalyzeError::Cancelled`] when `ct` is triggered before or
/// during the run.
#[instrument(skip_all, fields(path = %root.display()))]
#[allow(clippy::needless_pass_by_value)]
pub fn analyze_directory_with_progress(
    root: &Path,
    entries: Vec<WalkEntry>,
    progress: Arc<AtomicUsize>,
    ct: CancellationToken,
) -> Result<AnalysisOutput, AnalyzeError> {
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    // Directories stay in `entries` for the final tree rendering but are
    // not analyzed themselves.
    let file_entries: Vec<&WalkEntry> = entries.iter().filter(|e| !e.is_dir).collect();
    let start = Instant::now();
    tracing::debug!(file_count = file_entries.len(), root = %root.display(), "analysis start");
    let analysis_results: Vec<FileInfo> = file_entries
        .par_iter()
        .filter_map(|entry| {
            // Cancellation inside the parallel loop merely skips the
            // remaining work; the post-loop check reports Cancelled.
            if ct.is_cancelled() {
                return None;
            }
            let path_str = entry.path.display().to_string();
            let ext = entry.path.extension().and_then(|e| e.to_str());
            // Unreadable files (binary content, permissions, ...) still
            // advance progress but are dropped from the results.
            let Ok(source) = std::fs::read_to_string(&entry.path) else {
                progress.fetch_add(1, Ordering::Relaxed);
                return None;
            };
            let line_count = source.lines().count();
            let (language, function_count, class_count) = if let Some(ext_str) = ext {
                if let Some(lang) = language_for_extension(ext_str) {
                    let lang_str = lang.to_string();
                    // Best effort: a parse failure still reports the
                    // language, with zeroed counts.
                    match ElementExtractor::extract_with_depth(&source, &lang_str) {
                        Ok((func_count, class_count)) => (lang_str, func_count, class_count),
                        Err(_) => (lang_str, 0, 0),
                    }
                } else {
                    ("unknown".to_string(), 0, 0)
                }
            } else {
                ("unknown".to_string(), 0, 0)
            };
            progress.fetch_add(1, Ordering::Relaxed);
            let is_test = is_test_file(&entry.path);
            Some(FileInfo {
                path: path_str,
                line_count,
                function_count,
                class_count,
                language,
                is_test,
            })
        })
        .collect();
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    tracing::debug!(
        file_count = file_entries.len(),
        duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
        "analysis complete"
    );
    let formatted = format_structure(&entries, &analysis_results, None);
    Ok(AnalysisOutput {
        formatted,
        files: analysis_results,
        entries,
        next_cursor: None,
        subtree_counts: None,
    })
}
/// Walks `root` (to `max_depth`, if given) and analyzes every file found.
///
/// Convenience wrapper over [`analyze_directory_with_progress`] that
/// supplies a throwaway progress counter and a token that never fires.
///
/// # Errors
/// Propagates traversal and analysis errors.
#[instrument(skip_all, fields(path = %root.display()))]
pub fn analyze_directory(
    root: &Path,
    max_depth: Option<u32>,
) -> Result<AnalysisOutput, AnalyzeError> {
    let walked = walk_directory(root, max_depth)?;
    analyze_directory_with_progress(
        root,
        walked,
        Arc::new(AtomicUsize::new(0)),
        CancellationToken::new(),
    )
}
/// Chooses the analysis mode for a request.
///
/// A `focus` symbol always selects symbol-focused mode; otherwise the
/// mode follows whether `path` is a directory (overview) or not
/// (per-file details). A nonexistent path falls through to
/// `FileDetails`, since `is_dir` reports `false` for it.
#[must_use]
pub fn determine_mode(path: &str, focus: Option<&str>) -> AnalysisMode {
    match focus {
        Some(_) => AnalysisMode::SymbolFocus,
        None if Path::new(path).is_dir() => AnalysisMode::Overview,
        None => AnalysisMode::FileDetails,
    }
}
/// Analyzes a single source file on disk.
///
/// The language is inferred from the file extension; unknown extensions
/// fall back to `"unknown"`. For Python files, wildcard
/// (`from x import *`) imports are expanded into the target module's
/// exported names before formatting.
///
/// # Errors
/// I/O failures are surfaced as [`AnalyzeError::Parser`]; extraction
/// failures propagate directly.
#[instrument(skip_all, fields(path))]
pub fn analyze_file(
    path: &str,
    ast_recursion_limit: Option<usize>,
) -> Result<FileAnalysisOutput, AnalyzeError> {
    let start = Instant::now();
    let source = std::fs::read_to_string(path)
        .map_err(|e| AnalyzeError::Parser(crate::parser::ParserError::ParseError(e.to_string())))?;
    let line_count = source.lines().count();
    // NOTE: despite the name, `ext` holds the resolved language name
    // (or "unknown"), not the raw file extension.
    let ext = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .and_then(language_for_extension)
        .map_or_else(|| "unknown".to_string(), std::string::ToString::to_string);
    let mut semantic = SemanticExtractor::extract(&source, &ext, ast_recursion_limit)?;
    // Stamp every extracted reference with the file it came from.
    for r in &mut semantic.references {
        r.location = path.to_string();
    }
    if ext == "python" {
        resolve_wildcard_imports(Path::new(path), &mut semantic.imports);
    }
    let is_test = is_test_file(Path::new(path));
    let parent_dir = Path::new(path).parent();
    let formatted = format_file_details(path, &semantic, line_count, is_test, parent_dir);
    tracing::debug!(path = %path, language = %ext, functions = semantic.functions.len(), classes = semantic.classes.len(), imports = semantic.imports.len(), duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX), "file analysis complete");
    Ok(FileAnalysisOutput::new(
        formatted, semantic, line_count, None,
    ))
}
/// Analyzes in-memory source text.
///
/// `language` may be a file extension (`"rs"`, `"py"`) or a language
/// name matched case-insensitively (`"Rust"`, `"python"`). Extracted
/// references are tagged with the placeholder location `"<memory>"` and
/// the formatted output uses an empty path.
///
/// # Errors
/// [`AnalyzeError::UnsupportedLanguage`] when `language` matches neither
/// an extension nor a supported language name; parser errors otherwise.
#[inline]
pub fn analyze_str(
    source: &str,
    language: &str,
    ast_recursion_limit: Option<usize>,
) -> Result<FileAnalysisOutput, AnalyzeError> {
    // Extension lookup first; fall back to a case-insensitive match on
    // the canonical language names.
    let resolved = match language_for_extension(language) {
        Some(lang) => Some(lang),
        None => {
            let lowered = language.to_ascii_lowercase();
            supported_languages()
                .iter()
                .find(|&&candidate| candidate == lowered)
                .copied()
        }
    };
    let Some(lang) = resolved else {
        return Err(AnalyzeError::UnsupportedLanguage(language.to_string()));
    };
    let mut semantic = SemanticExtractor::extract(source, lang, ast_recursion_limit)?;
    for reference in &mut semantic.references {
        reference.location = "<memory>".to_string();
    }
    let line_count = source.lines().count();
    let formatted = format_file_details("", &semantic, line_count, false, None);
    Ok(FileAnalysisOutput::new(formatted, semantic, line_count, None))
}
/// One hop of a call chain, flattened for API consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct CallChainEntry {
    /// Name of the caller or callee at this hop.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "Symbol name of the caller or callee")
    )]
    pub symbol: String,
    /// File containing the symbol, relative to the analysis root when
    /// one was supplied.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "File path relative to the repository root")
    )]
    pub file: String,
    /// 1-indexed line of the definition or call site.
    #[cfg_attr(
        feature = "schemars",
        schemars(
            description = "Line number of the definition or call site (1-indexed)",
            schema_with = "crate::schema_helpers::integer_schema"
        )
    )]
    pub line: usize,
}
/// Result of a symbol-focused (call graph) analysis.
///
/// The `#[serde(skip)]` fields carry raw chain data for in-process
/// consumers (e.g. pagination); only the formatted text, cursor, and
/// flattened depth-1 caller/callee lists are serialized.
#[derive(Debug, Serialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
#[non_exhaustive]
pub struct FocusedAnalysisOutput {
    /// Human-readable rendering of the call graph analysis.
    #[cfg_attr(
        feature = "schemars",
        schemars(description = "Formatted text representation of the call graph analysis")
    )]
    pub formatted: String,
    /// Pagination cursor for the next page, omitted when exhausted.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[cfg_attr(
        feature = "schemars",
        schemars(
            description = "Opaque cursor token for the next page of results (absent when no more results)"
        )
    )]
    pub next_cursor: Option<String>,
    /// Incoming chains from production (non-test) callers; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub prod_chains: Vec<InternalCallChain>,
    /// Incoming chains from test callers; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub test_chains: Vec<InternalCallChain>,
    /// Outgoing (callee) chains; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub outgoing_chains: Vec<InternalCallChain>,
    /// Number of definitions found for the focus symbol; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub def_count: usize,
    /// Unique caller count before any impl-only filtering; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub unfiltered_caller_count: usize,
    /// Unique caller count after impl-only filtering; in-process only.
    #[serde(skip)]
    #[cfg_attr(feature = "schemars", schemars(skip))]
    pub impl_trait_caller_count: usize,
    /// Flattened depth-1 production callers (at most 10).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub callers: Option<Vec<CallChainEntry>>,
    /// Flattened depth-1 test callers (at most 10).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub test_callers: Option<Vec<CallChainEntry>>,
    /// Flattened depth-1 callees (at most 10).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub callees: Option<Vec<CallChainEntry>>,
}
/// Options for a symbol-focused analysis request.
#[derive(Clone)]
pub struct FocusedAnalysisConfig {
    /// Symbol to focus on (interpretation depends on `match_mode`).
    pub focus: String,
    /// How `focus` is matched against graph symbols.
    pub match_mode: SymbolMatchMode,
    /// Depth of caller/callee chain traversal.
    pub follow_depth: u32,
    /// Directory walk depth limit, if any.
    pub max_depth: Option<u32>,
    /// Recursion cap for AST extraction, if any.
    pub ast_recursion_limit: Option<usize>,
    /// Render the compact summary format instead of the full report.
    pub use_summary: bool,
    /// When `Some(true)`, keep only impl-trait caller edges.
    pub impl_only: Option<bool>,
}
/// Internal mirror of [`FocusedAnalysisConfig`] without `max_depth`
/// (the internal pipeline receives pre-walked entries instead).
#[derive(Clone)]
struct FocusedAnalysisParams {
    // Fields correspond one-to-one with FocusedAnalysisConfig.
    focus: String,
    match_mode: SymbolMatchMode,
    follow_depth: u32,
    ast_recursion_limit: Option<usize>,
    use_summary: bool,
    impl_only: Option<bool>,
}
/// Per-file semantic analyses paired with all `impl Trait` info found
/// across them.
type AnalysisResults = (Vec<(PathBuf, SemanticAnalysis)>, Vec<ImplTraitInfo>);
/// Runs semantic extraction over every file entry in parallel.
///
/// Files that cannot be read or parsed are skipped (progress is still
/// advanced for them). Each result's references are stamped with the
/// source path, and `impl Trait` records get their `path` filled in.
///
/// # Errors
/// Returns [`AnalyzeError::Cancelled`] when `ct` is triggered.
fn collect_file_analysis(
    entries: &[WalkEntry],
    progress: &Arc<AtomicUsize>,
    ct: &CancellationToken,
    ast_recursion_limit: Option<usize>,
) -> Result<AnalysisResults, AnalyzeError> {
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    let file_entries: Vec<&WalkEntry> = entries.iter().filter(|e| !e.is_dir).collect();
    let analysis_results: Vec<(PathBuf, SemanticAnalysis)> = file_entries
        .par_iter()
        .filter_map(|entry| {
            // Cancellation inside the loop skips the remaining work;
            // the post-loop check reports Cancelled.
            if ct.is_cancelled() {
                return None;
            }
            let ext = entry.path.extension().and_then(|e| e.to_str());
            let Ok(source) = std::fs::read_to_string(&entry.path) else {
                progress.fetch_add(1, Ordering::Relaxed);
                return None;
            };
            let language = if let Some(ext_str) = ext {
                language_for_extension(ext_str)
                    .map_or_else(|| "unknown".to_string(), std::string::ToString::to_string)
            } else {
                "unknown".to_string()
            };
            if let Ok(mut semantic) =
                SemanticExtractor::extract(&source, &language, ast_recursion_limit)
            {
                // Attribute references and impl-trait records to this file.
                for r in &mut semantic.references {
                    r.location = entry.path.display().to_string();
                }
                for trait_info in &mut semantic.impl_traits {
                    trait_info.path.clone_from(&entry.path);
                }
                progress.fetch_add(1, Ordering::Relaxed);
                Some((entry.path.clone(), semantic))
            } else {
                progress.fetch_add(1, Ordering::Relaxed);
                None
            }
        })
        .collect();
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    // Flatten the impl-trait info so graph construction sees it globally.
    let all_impl_traits: Vec<ImplTraitInfo> = analysis_results
        .iter()
        .flat_map(|(_, sem)| sem.impl_traits.iter().cloned())
        .collect();
    Ok((analysis_results, all_impl_traits))
}
/// Constructs a [`CallGraph`] from per-file semantic analyses plus the
/// collected `impl Trait` info, converting graph errors into
/// [`AnalyzeError`] via its `#[from]` impl.
fn build_call_graph(
    analysis_results: Vec<(PathBuf, SemanticAnalysis)>,
    all_impl_traits: &[ImplTraitInfo],
) -> Result<CallGraph, AnalyzeError> {
    let graph = CallGraph::build_from_results(analysis_results, all_impl_traits, false)?;
    Ok(graph)
}
/// Resolves the focus symbol against the graph and computes the caller
/// counts used by the impl-only filter header.
///
/// Returns `(resolved_name, unfiltered_caller_count,
/// impl_trait_caller_count)`, where both counts are counts of unique
/// caller names.
///
/// Side effect: when `impl_only` is `Some(true)`, the graph's caller
/// edges are filtered IN PLACE to keep only impl-trait edges, so all
/// later chain computation sees the filtered graph.
///
/// # Errors
/// `SymbolNotFound` (with a hint about other match modes) when an exact
/// match fails; indexed-resolution errors otherwise.
fn resolve_symbol(
    graph: &mut CallGraph,
    params: &FocusedAnalysisParams,
) -> Result<(String, usize, usize), AnalyzeError> {
    let resolved_focus = if params.match_mode == SymbolMatchMode::Exact {
        // Exact mode: accept the symbol only if it appears anywhere in
        // the graph — as a definition, a caller, or a callee.
        let exists = graph.definitions.contains_key(&params.focus)
            || graph.callers.contains_key(&params.focus)
            || graph.callees.contains_key(&params.focus);
        if exists {
            params.focus.clone()
        } else {
            return Err(crate::graph::GraphError::SymbolNotFound {
                symbol: params.focus.clone(),
                hint: "Try match_mode=insensitive for a case-insensitive search, or match_mode=prefix to list symbols starting with this name.".to_string(),
            }
            .into());
        }
    } else {
        graph.resolve_symbol_indexed(&params.focus, &params.match_mode)?
    };
    // Unique caller names BEFORE any impl-only filtering.
    let unfiltered_caller_count = graph.callers.get(&resolved_focus).map_or(0, |edges| {
        edges
            .iter()
            .map(|e| &e.neighbor_name)
            .collect::<std::collections::HashSet<_>>()
            .len()
    });
    let impl_trait_caller_count = if params.impl_only.unwrap_or(false) {
        // Filter the whole graph's caller edges down to impl-trait edges
        // (intentionally destructive — see doc comment), then recount.
        for edges in graph.callers.values_mut() {
            edges.retain(|e| e.is_impl_trait);
        }
        graph.callers.get(&resolved_focus).map_or(0, |edges| {
            edges
                .iter()
                .map(|e| &e.neighbor_name)
                .collect::<std::collections::HashSet<_>>()
                .len()
        })
    } else {
        unfiltered_caller_count
    };
    Ok((
        resolved_focus,
        unfiltered_caller_count,
        impl_trait_caller_count,
    ))
}
/// Result tuple of `compute_chains`: formatted text, production caller
/// chains, test caller chains, outgoing (callee) chains, and the number
/// of definitions found for the focus symbol.
type ChainComputeResult = (
    String,
    Vec<InternalCallChain>,
    Vec<InternalCallChain>,
    Vec<InternalCallChain>,
    usize,
);
/// Flattens up to the first ten call chains into `CallChainEntry`
/// records keyed by each chain's first hop.
///
/// Paths are made relative to `root` when possible (falling back to the
/// absolute path when `strip_prefix` fails). Returns `None` when no
/// usable entries remain.
fn chains_to_entries(
    chains: &[InternalCallChain],
    root: Option<&std::path::Path>,
) -> Option<Vec<CallChainEntry>> {
    let mut entries = Vec::new();
    for chain in chains.iter().take(10) {
        // Chains without a first hop contribute nothing.
        let Some((symbol, path, line)) = chain.chain.first() else {
            continue;
        };
        let file = root.map_or_else(
            || path.to_string_lossy().into_owned(),
            |base| {
                path.strip_prefix(base)
                    .unwrap_or(path.as_path())
                    .to_string_lossy()
                    .into_owned()
            },
        );
        entries.push(CallChainEntry {
            symbol: symbol.clone(),
            file,
            line: *line,
        });
    }
    if entries.is_empty() { None } else { Some(entries) }
}
/// Computes incoming/outgoing chains for the resolved symbol and renders
/// the formatted report.
///
/// Incoming chains are split into production vs. test callers by the
/// chain's FIRST hop: a caller located in a test file, or whose name
/// starts with `test_`, counts as a test caller. When `impl_only` is
/// active a "FILTER: ... (N of M callers shown)" header is prepended —
/// the graph itself was already filtered by `resolve_symbol`.
///
/// # Errors
/// Propagates graph traversal and formatting errors.
fn compute_chains(
    graph: &CallGraph,
    resolved_focus: &str,
    root: &Path,
    params: &FocusedAnalysisParams,
    unfiltered_caller_count: usize,
    impl_trait_caller_count: usize,
) -> Result<ChainComputeResult, AnalyzeError> {
    let def_count = graph.definitions.get(resolved_focus).map_or(0, Vec::len);
    let incoming_chains = graph.find_incoming_chains(resolved_focus, params.follow_depth)?;
    let outgoing_chains = graph.find_outgoing_chains(resolved_focus, params.follow_depth)?;
    // Chains with no hops at all land in the production bucket
    // (`is_none_or` returns true for an empty chain).
    let (prod_chains, test_chains): (Vec<_>, Vec<_>) =
        incoming_chains.iter().cloned().partition(|chain| {
            chain
                .chain
                .first()
                .is_none_or(|(name, path, _)| !is_test_file(path) && !name.starts_with("test_"))
        });
    let mut formatted = if params.use_summary {
        format_focused_summary_internal(
            graph,
            resolved_focus,
            params.follow_depth,
            Some(root),
            Some(&incoming_chains),
            Some(&outgoing_chains),
        )?
    } else {
        format_focused_internal(
            graph,
            resolved_focus,
            params.follow_depth,
            Some(root),
            Some(&incoming_chains),
            Some(&outgoing_chains),
        )?
    };
    if params.impl_only.unwrap_or(false) {
        let filter_header = format!(
            "FILTER: impl_only=true ({impl_trait_caller_count} of {unfiltered_caller_count} callers shown)\n",
        );
        formatted = format!("{filter_header}{formatted}");
    }
    Ok((
        formatted,
        prod_chains,
        test_chains,
        outgoing_chains,
        def_count,
    ))
}
/// Symbol-focused analysis: walks `root`, then runs the focused pipeline
/// with the supplied progress counter and cancellation token.
///
/// # Errors
/// Propagates traversal, parsing, graph, and formatting errors, or
/// [`AnalyzeError::Cancelled`] when `ct` fires.
#[allow(clippy::needless_pass_by_value)]
pub fn analyze_focused_with_progress(
    root: &Path,
    params: &FocusedAnalysisConfig,
    progress: Arc<AtomicUsize>,
    ct: CancellationToken,
) -> Result<FocusedAnalysisOutput, AnalyzeError> {
    let entries = walk_directory(root, params.max_depth)?;
    analyze_focused_with_progress_with_entries_internal(
        root,
        params.max_depth,
        &progress,
        &ct,
        &FocusedAnalysisParams {
            focus: params.focus.clone(),
            match_mode: params.match_mode.clone(),
            follow_depth: params.follow_depth,
            ast_recursion_limit: params.ast_recursion_limit,
            use_summary: params.use_summary,
            impl_only: params.impl_only,
        },
        &entries,
    )
}
/// Core of the symbol-focused pipeline: extract, build the call graph,
/// resolve the symbol, compute chains, and assemble the output.
///
/// `_max_depth` is unused here (entries are pre-walked) but kept so both
/// public wrappers share one internal signature. A single-file `root`
/// yields a stub response with an explanatory message rather than an
/// error.
///
/// # Errors
/// Returns [`AnalyzeError::Cancelled`] at each pipeline checkpoint when
/// `ct` fires, or the failing stage's error otherwise.
#[instrument(skip_all, fields(path = %root.display(), symbol = %params.focus))]
fn analyze_focused_with_progress_with_entries_internal(
    root: &Path,
    _max_depth: Option<u32>,
    progress: &Arc<AtomicUsize>,
    ct: &CancellationToken,
    params: &FocusedAnalysisParams,
    entries: &[WalkEntry],
) -> Result<FocusedAnalysisOutput, AnalyzeError> {
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    // Cross-file call graphs need a directory root.
    if root.is_file() {
        let formatted =
            "Single-file focus not supported. Please provide a directory path for cross-file call graph analysis.\n"
                .to_string();
        return Ok(FocusedAnalysisOutput {
            formatted,
            next_cursor: None,
            prod_chains: vec![],
            test_chains: vec![],
            outgoing_chains: vec![],
            def_count: 0,
            unfiltered_caller_count: 0,
            impl_trait_caller_count: 0,
            callers: None,
            test_callers: None,
            callees: None,
        });
    }
    let (analysis_results, all_impl_traits) =
        collect_file_analysis(entries, progress, ct, params.ast_recursion_limit)?;
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    // Mutable because resolve_symbol may filter caller edges in place
    // when impl_only is requested.
    let mut graph = build_call_graph(analysis_results, &all_impl_traits)?;
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    let (resolved_focus, unfiltered_caller_count, impl_trait_caller_count) =
        resolve_symbol(&mut graph, params)?;
    if ct.is_cancelled() {
        return Err(AnalyzeError::Cancelled);
    }
    let (formatted, prod_chains, test_chains, outgoing_chains, def_count) = compute_chains(
        &graph,
        &resolved_focus,
        root,
        params,
        unfiltered_caller_count,
        impl_trait_caller_count,
    )?;
    // The structured caller/callee lists are always depth-1: reuse the
    // chains just computed when follow_depth <= 1, otherwise recompute
    // at depth 1 (best effort — failures yield empty lists).
    let (depth1_callers, depth1_test_callers, depth1_callees) = if params.follow_depth <= 1 {
        let callers = chains_to_entries(&prod_chains, Some(root));
        let test_callers = chains_to_entries(&test_chains, Some(root));
        let callees = chains_to_entries(&outgoing_chains, Some(root));
        (callers, test_callers, callees)
    } else {
        let incoming1 = graph
            .find_incoming_chains(&resolved_focus, 1)
            .unwrap_or_default();
        let outgoing1 = graph
            .find_outgoing_chains(&resolved_focus, 1)
            .unwrap_or_default();
        // Same prod/test split as compute_chains, applied to the
        // depth-1 chains.
        let (prod1, test1): (Vec<_>, Vec<_>) = incoming1.into_iter().partition(|chain| {
            chain
                .chain
                .first()
                .is_none_or(|(name, path, _)| !is_test_file(path) && !name.starts_with("test_"))
        });
        let callers = chains_to_entries(&prod1, Some(root));
        let test_callers = chains_to_entries(&test1, Some(root));
        let callees = chains_to_entries(&outgoing1, Some(root));
        (callers, test_callers, callees)
    };
    Ok(FocusedAnalysisOutput {
        formatted,
        next_cursor: None,
        callers: depth1_callers,
        test_callers: depth1_test_callers,
        callees: depth1_callees,
        prod_chains,
        test_chains,
        outgoing_chains,
        def_count,
        unfiltered_caller_count,
        impl_trait_caller_count,
    })
}
pub fn analyze_focused_with_progress_with_entries(
root: &Path,
params: &FocusedAnalysisConfig,
progress: &Arc<AtomicUsize>,
ct: &CancellationToken,
entries: &[WalkEntry],
) -> Result<FocusedAnalysisOutput, AnalyzeError> {
let internal_params = FocusedAnalysisParams {
focus: params.focus.clone(),
match_mode: params.match_mode.clone(),
follow_depth: params.follow_depth,
ast_recursion_limit: params.ast_recursion_limit,
use_summary: params.use_summary,
impl_only: params.impl_only,
};
analyze_focused_with_progress_with_entries_internal(
root,
params.max_depth,
progress,
ct,
&internal_params,
entries,
)
}
/// Exact-match symbol-focused analysis with default knobs: full report
/// (no summary), no impl-only filter, a throwaway progress counter, and
/// a token that never fires.
///
/// # Errors
/// Propagates traversal and pipeline errors.
#[instrument(skip_all, fields(path = %root.display(), symbol = %focus))]
pub fn analyze_focused(
    root: &Path,
    focus: &str,
    follow_depth: u32,
    max_depth: Option<u32>,
    ast_recursion_limit: Option<usize>,
) -> Result<FocusedAnalysisOutput, AnalyzeError> {
    let config = FocusedAnalysisConfig {
        focus: focus.to_owned(),
        match_mode: SymbolMatchMode::Exact,
        follow_depth,
        max_depth,
        ast_recursion_limit,
        use_summary: false,
        impl_only: None,
    };
    let entries = walk_directory(root, max_depth)?;
    analyze_focused_with_progress_with_entries(
        root,
        &config,
        &Arc::new(AtomicUsize::new(0)),
        &CancellationToken::new(),
        &entries,
    )
}
/// Lightweight per-file module summary: name, line count, language, and
/// flat lists of functions and imports (no references or call data).
///
/// # Errors
/// I/O failures and unsupported/missing extensions are reported as
/// [`AnalyzeError::Parser`]; extraction errors propagate.
#[instrument(skip_all, fields(path))]
pub fn analyze_module_file(path: &str) -> Result<crate::types::ModuleInfo, AnalyzeError> {
    let source = std::fs::read_to_string(path)
        .map_err(|e| AnalyzeError::Parser(crate::parser::ParserError::ParseError(e.to_string())))?;
    let file_path = Path::new(path);
    let name = file_path
        .file_name()
        .and_then(|s| s.to_str())
        .unwrap_or("unknown")
        .to_string();
    let line_count = source.lines().count();
    // Unlike analyze_file, an unrecognized extension is a hard error
    // here rather than a fallback to "unknown".
    let language = file_path
        .extension()
        .and_then(|e| e.to_str())
        .and_then(language_for_extension)
        .ok_or_else(|| {
            AnalyzeError::Parser(crate::parser::ParserError::ParseError(
                "unsupported or missing file extension".to_string(),
            ))
        })?;
    let semantic = SemanticExtractor::extract(&source, language, None)?;
    // Project the full semantic result down to name/line pairs.
    let functions = semantic
        .functions
        .into_iter()
        .map(|f| crate::types::ModuleFunctionInfo {
            name: f.name,
            line: f.line,
        })
        .collect();
    let imports = semantic
        .imports
        .into_iter()
        .map(|i| crate::types::ModuleImportInfo {
            module: i.module,
            items: i.items,
        })
        .collect();
    Ok(crate::types::ModuleInfo {
        name,
        line_count,
        language: language.to_string(),
        functions,
        imports,
    })
}
/// Expands Python `from <module> import *` entries in place.
///
/// Only entries whose item list is exactly `["*"]` are touched.
/// Resolution is a no-op when the current file's path cannot be
/// canonicalized; per-target results are cached across imports.
fn resolve_wildcard_imports(file_path: &Path, imports: &mut [ImportInfo]) {
    use std::collections::HashMap;
    let Ok(canonical_self) = file_path.canonicalize() else {
        tracing::debug!(file = ?file_path, "unable to canonicalize current file path");
        return;
    };
    // Cache keyed by the target module's canonical path so repeated
    // wildcard imports of the same module are parsed only once.
    let mut cache: HashMap<PathBuf, Vec<String>> = HashMap::new();
    for import in imports.iter_mut().filter(|i| i.items == ["*"]) {
        resolve_single_wildcard(import, file_path, &canonical_self, &mut cache);
    }
}
/// Expands one `from <relative-module> import *` entry in place.
///
/// Only relative imports (leading dots) are handled; absolute wildcard
/// imports are left untouched. Results are memoized in `resolved_cache`
/// keyed by the target's canonical path, and importing from the file
/// itself is rejected.
fn resolve_single_wildcard(
    import: &mut ImportInfo,
    file_path: &Path,
    file_path_canonical: &Path,
    resolved_cache: &mut std::collections::HashMap<PathBuf, Vec<String>>,
) {
    let module = import.module.clone();
    // Count leading dots: 1 = current package, each extra dot climbs
    // one directory level.
    let dot_count = module.chars().take_while(|c| *c == '.').count();
    if dot_count == 0 {
        // Absolute import: the target could live anywhere on sys.path,
        // which we cannot resolve from here.
        return;
    }
    let module_path = module.trim_start_matches('.');
    let Some(target_to_read) = locate_target_file(file_path, dot_count, module_path, &module)
    else {
        return;
    };
    let Ok(canonical) = target_to_read.canonicalize() else {
        tracing::debug!(target = ?target_to_read, import = %module, "unable to canonicalize path");
        return;
    };
    if canonical == file_path_canonical {
        tracing::debug!(target = ?canonical, import = %module, "cannot import from self");
        return;
    }
    if let Some(cached) = resolved_cache.get(&canonical) {
        tracing::debug!(import = %module, symbols_count = cached.len(), "using cached symbols");
        import.items.clone_from(cached);
        return;
    }
    if let Some(symbols) = parse_target_symbols(&target_to_read, &module) {
        tracing::debug!(import = %module, resolved_count = symbols.len(), "wildcard import resolved");
        import.items.clone_from(&symbols);
        resolved_cache.insert(canonical, symbols);
    }
}
/// Resolves a Python relative import to the file that should be parsed
/// for its symbols.
///
/// `dot_count` is the number of leading dots in the import (1 = current
/// package, each extra dot climbs one directory). `module_path` is the
/// dotted path with the leading dots stripped; `module` is the original
/// spelling, used only for logging.
///
/// Returns the module's `.py` file, the package's `__init__.py` when
/// the target is a package (or the import is `from . import *`), or
/// `None` when nothing suitable exists on disk.
fn locate_target_file(
    file_path: &Path,
    dot_count: usize,
    module_path: &str,
    module: &str,
) -> Option<PathBuf> {
    let mut target_dir = file_path.parent()?.to_path_buf();
    // One dot means "current package"; each additional dot pops a level.
    for _ in 1..dot_count {
        if !target_dir.pop() {
            tracing::debug!(import = %module, "unable to climb {} levels", dot_count.saturating_sub(1));
            return None;
        }
    }
    let target_file = if module_path.is_empty() {
        target_dir.join("__init__.py")
    } else {
        let rel_path = module_path.replace('.', "/");
        target_dir.join(format!("{rel_path}.py"))
    };
    if target_file.exists() {
        return Some(target_file);
    }
    // `foo.py` missing: `foo/` may be a package; use its `__init__.py`.
    let as_dir = target_file.with_extension("");
    if as_dir.is_dir() {
        let init = as_dir.join("__init__.py");
        if init.exists() {
            return Some(init);
        }
        // Fix: this branch previously returned None silently while every
        // other failure path logged — log it so unresolved imports are
        // diagnosable.
        tracing::debug!(target = ?init, import = %module, "package found but __init__.py missing");
        return None;
    }
    tracing::debug!(target = ?target_file, import = %module, "target file not found");
    None
}
/// Parses a Python file and returns the names a wildcard import would
/// bring into scope.
///
/// Prefers an explicit `__all__` list; otherwise falls back to all
/// top-level function/class names not prefixed with an underscore
/// (mirroring Python's `import *` default). Returns `None` when the
/// file cannot be read, the Python grammar is unavailable, or parsing
/// fails outright.
fn parse_target_symbols(target_path: &Path, module: &str) -> Option<Vec<String>> {
    use tree_sitter::Parser;
    let source = match std::fs::read_to_string(target_path) {
        Ok(s) => s,
        Err(e) => {
            tracing::debug!(target = ?target_path, import = %module, error = %e, "unable to read target file");
            return None;
        }
    };
    let lang_info = crate::languages::get_language_info("python")?;
    let mut parser = Parser::new();
    if parser.set_language(&lang_info.language).is_err() {
        return None;
    }
    let tree = parser.parse(&source, None)?;
    let mut symbols = Vec::new();
    // First choice: an explicit __all__ export list.
    extract_all_from_tree(&tree, &source, &mut symbols);
    if !symbols.is_empty() {
        tracing::debug!(import = %module, symbols = ?symbols, "using __all__ symbols");
        return Some(symbols);
    }
    // No __all__: collect public top-level function/class names instead.
    let root = tree.root_node();
    let mut cursor = root.walk();
    for child in root.children(&mut cursor) {
        if matches!(child.kind(), "function_definition" | "class_definition")
            && let Some(name_node) = child.child_by_field_name("name")
        {
            let name = source[name_node.start_byte()..name_node.end_byte()].to_string();
            if !name.starts_with('_') {
                symbols.push(name);
            }
        }
    }
    tracing::debug!(import = %module, fallback_symbols = ?symbols, "using fallback function/class names");
    Some(symbols)
}
/// Scans the module's top level for an `__all__ = [...]` assignment and
/// appends its string elements to `result`.
///
/// Depending on grammar version, a top-level assignment is wrapped in
/// either a `simple_statement` or an `expression_statement` node; the
/// inner `assignment` shape is identical, so both wrappers share one
/// code path (previously duplicated branch-for-branch).
fn extract_all_from_tree(tree: &tree_sitter::Tree, source: &str, result: &mut Vec<String>) {
    let root = tree.root_node();
    let mut cursor = root.walk();
    for child in root.children(&mut cursor) {
        if !matches!(child.kind(), "simple_statement" | "expression_statement") {
            continue;
        }
        let mut stmt_cursor = child.walk();
        for stmt in child.children(&mut stmt_cursor) {
            if stmt.kind() == "assignment"
                && let Some(left) = stmt.child_by_field_name("left")
            {
                let target_text = source[left.start_byte()..left.end_byte()].trim();
                if target_text == "__all__"
                    && let Some(right) = stmt.child_by_field_name("right")
                {
                    extract_string_list_from_list_node(&right, source, result);
                }
            }
        }
    }
}
/// Appends the string elements of a tree-sitter list node (e.g. the
/// right-hand side of `__all__ = [...]`) to `result`, stripping the
/// surrounding single or double quotes and skipping empty strings.
fn extract_string_list_from_list_node(
    list_node: &tree_sitter::Node,
    source: &str,
    result: &mut Vec<String>,
) {
    let mut cursor = list_node.walk();
    result.extend(
        list_node
            .named_children(&mut cursor)
            .filter(|child| child.kind() == "string")
            .filter_map(|child| {
                let raw = source[child.start_byte()..child.end_byte()].trim();
                let name = raw.trim_matches('"').trim_matches('\'');
                if name.is_empty() {
                    None
                } else {
                    Some(name.to_string())
                }
            }),
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formatter::format_focused_paginated;
    use crate::graph::InternalCallChain;
    use crate::pagination::{PaginationMode, decode_cursor, paginate_slice};
    use std::fs;
    use std::path::PathBuf;
    use tempfile::TempDir;
    // `analyze_str` accepts a file extension for the language argument.
    #[cfg(feature = "lang-rust")]
    #[test]
    fn analyze_str_rust_happy_path() {
        let source = "fn hello() -> i32 { 42 }";
        let result = analyze_str(source, "rs", None);
        assert!(result.is_ok());
    }
    #[cfg(feature = "lang-python")]
    #[test]
    fn analyze_str_python_happy_path() {
        let source = "def greet(name):\n return f'Hello {name}'";
        let result = analyze_str(source, "py", None);
        assert!(result.is_ok());
    }
    // ... or a canonical language name.
    #[cfg(feature = "lang-rust")]
    #[test]
    fn analyze_str_rust_by_language_name() {
        let source = "fn hello() -> i32 { 42 }";
        let result = analyze_str(source, "rust", None);
        assert!(result.is_ok());
    }
    #[cfg(feature = "lang-python")]
    #[test]
    fn analyze_str_python_by_language_name() {
        let source = "def greet(name):\n return f'Hello {name}'";
        let result = analyze_str(source, "python", None);
        assert!(result.is_ok());
    }
    // Language-name matching is case-insensitive.
    #[cfg(feature = "lang-rust")]
    #[test]
    fn analyze_str_rust_mixed_case() {
        let source = "fn hello() -> i32 { 42 }";
        let result = analyze_str(source, "RuSt", None);
        assert!(result.is_ok());
    }
    #[cfg(feature = "lang-python")]
    #[test]
    fn analyze_str_python_mixed_case() {
        let source = "def greet(name):\n return f'Hello {name}'";
        let result = analyze_str(source, "PyThOn", None);
        assert!(result.is_ok());
    }
    // Unknown languages surface the original spelling in the error.
    #[test]
    fn analyze_str_unsupported_language() {
        let result = analyze_str("code", "brainfuck", None);
        assert!(
            matches!(result, Err(AnalyzeError::UnsupportedLanguage(lang)) if lang == "brainfuck")
        );
    }
    // First page of caller pagination fills the page and yields a cursor.
    #[cfg(feature = "lang-rust")]
    #[test]
    fn test_symbol_focus_callers_pagination_first_page() {
        let temp_dir = TempDir::new().unwrap();
        let mut code = String::from("fn target() {}\n");
        for i in 0..15 {
            code.push_str(&format!("fn caller_{:02}() {{ target(); }}\n", i));
        }
        fs::write(temp_dir.path().join("lib.rs"), &code).unwrap();
        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let paginated = paginate_slice(&output.prod_chains, 0, 5, PaginationMode::Callers)
            .expect("paginate failed");
        assert!(
            paginated.total >= 5,
            "should have enough callers to paginate"
        );
        assert!(
            paginated.next_cursor.is_some(),
            "should have next_cursor for page 1"
        );
        assert_eq!(paginated.items.len(), 5);
    }
    // Decoding page 1's cursor must yield page 2's offset, and the
    // formatted header must reflect that offset.
    #[test]
    fn test_symbol_focus_callers_pagination_second_page() {
        let temp_dir = TempDir::new().unwrap();
        let mut code = String::from("fn target() {}\n");
        for i in 0..12 {
            code.push_str(&format!("fn caller_{:02}() {{ target(); }}\n", i));
        }
        fs::write(temp_dir.path().join("lib.rs"), &code).unwrap();
        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let total_prod = output.prod_chains.len();
        if total_prod > 5 {
            let p1 = paginate_slice(&output.prod_chains, 0, 5, PaginationMode::Callers)
                .expect("paginate failed");
            assert!(p1.next_cursor.is_some());
            let cursor_str = p1.next_cursor.unwrap();
            let cursor_data = decode_cursor(&cursor_str).expect("decode failed");
            let p2 = paginate_slice(
                &output.prod_chains,
                cursor_data.offset,
                5,
                PaginationMode::Callers,
            )
            .expect("paginate failed");
            let formatted = format_focused_paginated(
                &p2.items,
                total_prod,
                PaginationMode::Callers,
                "target",
                &output.prod_chains,
                &output.test_chains,
                &output.outgoing_chains,
                output.def_count,
                cursor_data.offset,
                Some(temp_dir.path()),
                true,
            );
            // Header ranges are 1-indexed.
            let expected_start = cursor_data.offset + 1;
            assert!(
                formatted.contains(&format!("CALLERS ({}", expected_start)),
                "header should show page 2 range, got: {}",
                formatted
            );
        }
    }
    #[test]
    fn test_chains_to_entries_empty_returns_none() {
        let chains: Vec<InternalCallChain> = vec![];
        let result = chains_to_entries(&chains, None);
        assert!(result.is_none());
    }
    // Paths should be reported relative to the provided root.
    #[test]
    fn test_chains_to_entries_with_data_returns_entries() {
        let chains = vec![
            InternalCallChain {
                chain: vec![("caller1".to_string(), PathBuf::from("/root/lib.rs"), 10)],
            },
            InternalCallChain {
                chain: vec![("caller2".to_string(), PathBuf::from("/root/other.rs"), 20)],
            },
        ];
        let root = PathBuf::from("/root");
        let result = chains_to_entries(&chains, Some(root.as_path()));
        assert!(result.is_some());
        let entries = result.unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].symbol, "caller1");
        assert_eq!(entries[0].file, "lib.rs");
        assert_eq!(entries[0].line, 10);
        assert_eq!(entries[1].symbol, "caller2");
        assert_eq!(entries[1].file, "other.rs");
        assert_eq!(entries[1].line, 20);
    }
    // Callee pagination headers show the 1-indexed range and the total.
    #[test]
    fn test_symbol_focus_callees_pagination() {
        let temp_dir = TempDir::new().unwrap();
        let mut code = String::from("fn target() {\n");
        for i in 0..10 {
            code.push_str(&format!(" callee_{:02}();\n", i));
        }
        code.push_str("}\n");
        for i in 0..10 {
            code.push_str(&format!("fn callee_{:02}() {{}}\n", i));
        }
        fs::write(temp_dir.path().join("lib.rs"), &code).unwrap();
        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let total_callees = output.outgoing_chains.len();
        if total_callees > 3 {
            let paginated = paginate_slice(&output.outgoing_chains, 0, 3, PaginationMode::Callees)
                .expect("paginate failed");
            let formatted = format_focused_paginated(
                &paginated.items,
                total_callees,
                PaginationMode::Callees,
                "target",
                &output.prod_chains,
                &output.test_chains,
                &output.outgoing_chains,
                output.def_count,
                0,
                Some(temp_dir.path()),
                true,
            );
            assert!(
                formatted.contains(&format!(
                    "CALLEES (1-{} of {})",
                    paginated.items.len(),
                    total_callees
                )),
                "header should show callees range, got: {}",
                formatted
            );
        }
    }
    // A symbol called only from tests has no production callers.
    #[test]
    fn test_symbol_focus_empty_prod_callers() {
        let temp_dir = TempDir::new().unwrap();
        let code = r#"
fn target() {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_something() { target(); }
}
"#;
        fs::write(temp_dir.path().join("lib.rs"), code).unwrap();
        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let paginated = paginate_slice(&output.prod_chains, 0, 100, PaginationMode::Callers)
            .expect("paginate failed");
        assert_eq!(paginated.items.len(), output.prod_chains.len());
        assert!(
            paginated.next_cursor.is_none(),
            "no next_cursor for empty or single-page prod_chains"
        );
    }
    // impl_only filtering must both prepend the FILTER header and keep
    // its N-of-M counts consistent with the output struct's fields.
    #[test]
    fn test_impl_only_filter_header_correct_counts() {
        let temp_dir = TempDir::new().unwrap();
        let code = r#"
trait MyTrait {
fn focus_symbol();
}
struct SomeType;
impl MyTrait for SomeType {
fn focus_symbol() {}
}
fn impl_caller() {
SomeType::focus_symbol();
}
fn regular_caller() {
SomeType::focus_symbol();
}
"#;
        fs::write(temp_dir.path().join("lib.rs"), code).unwrap();
        let params = FocusedAnalysisConfig {
            focus: "focus_symbol".to_string(),
            match_mode: SymbolMatchMode::Insensitive,
            follow_depth: 1,
            max_depth: None,
            ast_recursion_limit: None,
            use_summary: false,
            impl_only: Some(true),
        };
        let output = analyze_focused_with_progress(
            temp_dir.path(),
            &params,
            Arc::new(AtomicUsize::new(0)),
            CancellationToken::new(),
        )
        .unwrap();
        assert!(
            output.formatted.contains("FILTER: impl_only=true"),
            "formatted output should contain FILTER header for impl_only=true, got: {}",
            output.formatted
        );
        assert!(
            output.impl_trait_caller_count < output.unfiltered_caller_count,
            "impl_trait_caller_count ({}) should be less than unfiltered_caller_count ({})",
            output.impl_trait_caller_count,
            output.unfiltered_caller_count
        );
        let filter_line = output
            .formatted
            .lines()
            .find(|line| line.contains("FILTER: impl_only=true"))
            .expect("should find FILTER line");
        assert!(
            filter_line.contains(&format!(
                "({} of {} callers shown)",
                output.impl_trait_caller_count, output.unfiltered_caller_count
            )),
            "FILTER line should show correct N of M counts, got: {}",
            filter_line
        );
    }
    // The caller count printed in the FOCUS line must equal the number
    // of unique first-hop callers in prod_chains.
    #[test]
    fn test_callers_count_matches_formatted_output() {
        let temp_dir = TempDir::new().unwrap();
        let code = r#"
fn target() {}
fn caller_a() { target(); }
fn caller_b() { target(); }
fn caller_c() { target(); }
"#;
        fs::write(temp_dir.path().join("lib.rs"), code).unwrap();
        let output = analyze_focused(temp_dir.path(), "target", 1, None, None).unwrap();
        let formatted = &output.formatted;
        let callers_count_from_output = formatted
            .lines()
            .find(|line| line.contains("FOCUS:"))
            .and_then(|line| {
                line.split(',')
                    .find(|part| part.contains("callers"))
                    .and_then(|part| {
                        part.trim()
                            .split_whitespace()
                            .next()
                            .and_then(|s| s.parse::<usize>().ok())
                    })
            })
            .expect("should find CALLERS count in formatted output");
        let expected_callers_count = output
            .prod_chains
            .iter()
            .filter_map(|chain| chain.chain.first().map(|(name, _, _)| name))
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert_eq!(
            callers_count_from_output, expected_callers_count,
            "CALLERS count in formatted output should match unique-first-caller count in prod_chains"
        );
    }
}