use std::collections::HashSet;
use std::fmt::Write as _;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use streaming_iterator::StreamingIterator;
use crate::lang::treesitter::{extract_definition_name, DEFINITION_KINDS};
use crate::cache::OutlineCache;
use crate::error::SrcwalkError;
use crate::format::rel_nonempty;
use crate::lang::detect_file_type;
use crate::lang::outline::outline_language;
use crate::session::Session;
use crate::types::FileType;
const IMPACT_FANOUT_THRESHOLD: usize = 10;
const IMPACT_MAX_RESULTS: usize = 15;
const BATCH_EARLY_QUIT: usize = 50;
pub(super) const TOP_LEVEL: &str = "<top-level>";
#[derive(Debug)]
pub struct CallerMatch {
pub path: PathBuf,
pub line: u32,
pub calling_function: String,
pub call_text: String,
pub caller_range: Option<(u32, u32)>,
pub content: Arc<String>,
}
pub fn find_callers(
target: &str,
scope: &Path,
bloom: &crate::index::bloom::BloomFilterCache,
glob: Option<&str>,
cache: Option<&crate::cache::OutlineCache>,
) -> Result<Vec<CallerMatch>, SrcwalkError> {
let matches: Mutex<Vec<CallerMatch>> = Mutex::new(Vec::new());
let found_count = AtomicUsize::new(0);
let needle = target.as_bytes();
let walker = crate::search::walker(scope, glob)?;
walker.run(|| {
let matches = &matches;
let found_count = &found_count;
Box::new(move |entry| {
let Ok(entry) = entry else {
return ignore::WalkState::Continue;
};
if !entry.file_type().is_some_and(|ft| ft.is_file()) {
return ignore::WalkState::Continue;
}
let path = entry.path();
let (file_len, mtime) = match std::fs::metadata(path) {
Ok(meta) => (
meta.len(),
meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH),
),
Err(_) => return ignore::WalkState::Continue,
};
if file_len > 500_000 {
return ignore::WalkState::Continue;
}
if crate::search::io::is_minified_filename(path) {
return ignore::WalkState::Continue;
}
let Some(bytes) = crate::search::read_file_bytes(path, file_len) else {
return ignore::WalkState::Continue;
};
if memchr::memmem::find(&bytes, needle).is_none() {
return ignore::WalkState::Continue;
}
if file_len >= crate::search::io::MINIFIED_CHECK_THRESHOLD
&& crate::search::io::looks_minified(&bytes)
{
return ignore::WalkState::Continue;
}
let Ok(content) = std::str::from_utf8(&bytes) else {
return ignore::WalkState::Continue;
};
if !bloom.contains(path, mtime, content, target) {
return ignore::WalkState::Continue;
}
let file_type = detect_file_type(path);
let FileType::Code(lang) = file_type else {
return ignore::WalkState::Continue;
};
let Some(ts_lang) = outline_language(lang) else {
return ignore::WalkState::Continue;
};
let file_callers =
find_callers_treesitter(path, target, &ts_lang, content, lang, mtime, cache);
if !file_callers.is_empty() {
found_count.fetch_add(file_callers.len(), Ordering::Relaxed);
let mut all = matches
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
all.extend(file_callers);
}
ignore::WalkState::Continue
})
});
Ok(matches
.into_inner()
.unwrap_or_else(std::sync::PoisonError::into_inner))
}
fn find_callers_treesitter(
path: &Path,
target: &str,
ts_lang: &tree_sitter::Language,
content: &str,
lang: crate::types::Lang,
mtime: std::time::SystemTime,
cache: Option<&crate::cache::OutlineCache>,
) -> Vec<CallerMatch> {
let Some(query_str) = crate::search::callees::callee_query_str(lang) else {
return Vec::new();
};
let tree = if let Some(c) = cache {
let Some(tree) = c.get_or_parse(path, mtime, content, ts_lang) else {
return Vec::new();
};
tree
} else {
let mut parser = tree_sitter::Parser::new();
if parser.set_language(ts_lang).is_err() {
return Vec::new();
}
let Some(tree) = parser.parse(content, None) else {
return Vec::new();
};
tree
};
let content_bytes = content.as_bytes();
let lines: Vec<&str> = content.lines().collect();
let shared_content: Arc<String> = Arc::new(content.to_string());
let Some(callers) = crate::search::callees::with_callee_query(ts_lang, query_str, |query| {
let Some(callee_idx) = query.capture_index_for_name("callee") else {
return Vec::new();
};
let mut cursor = tree_sitter::QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), content_bytes);
let mut callers = Vec::new();
while let Some(m) = matches.next() {
for cap in m.captures {
if cap.index != callee_idx {
continue;
}
let Ok(text) = cap.node.utf8_text(content_bytes) else {
continue;
};
if text != target {
continue;
}
let line = cap.node.start_position().row as u32 + 1;
let call_node = cap.node.parent().unwrap_or(cap.node);
let same_line = call_node.start_position().row == call_node.end_position().row;
let call_text: String = if same_line {
let row = call_node.start_position().row;
if row < lines.len() {
lines[row].trim().to_string()
} else {
text.to_string()
}
} else {
text.to_string()
};
let (calling_function, caller_range) =
find_enclosing_function(cap.node, &lines, lang);
callers.push(CallerMatch {
path: path.to_path_buf(),
line,
calling_function,
call_text,
caller_range,
content: Arc::clone(&shared_content),
});
}
}
callers
}) else {
return Vec::new();
};
callers
}
pub(crate) fn find_callers_batch(
targets: &HashSet<String>,
scope: &Path,
bloom: &crate::index::bloom::BloomFilterCache,
glob: Option<&str>,
cache: Option<&crate::cache::OutlineCache>,
early_quit: Option<usize>,
) -> Result<Vec<(String, CallerMatch)>, SrcwalkError> {
let matches: Mutex<Vec<(String, CallerMatch)>> = Mutex::new(Vec::new());
let found_count = AtomicUsize::new(0);
let target_vec: Vec<&str> = targets.iter().map(String::as_str).collect();
let ac = if target_vec.len() >= 3 {
aho_corasick::AhoCorasick::new(&target_vec).ok()
} else {
None
};
let mut sorted_targets: Vec<&str> = target_vec.clone();
sorted_targets.sort_by_key(|t| std::cmp::Reverse(t.len()));
let walker = crate::search::walker(scope, glob)?;
walker.run(|| {
let matches = &matches;
let found_count = &found_count;
let ac = ac.as_ref();
let sorted_targets = &sorted_targets;
Box::new(move |entry| {
if let Some(cap) = early_quit {
if found_count.load(Ordering::Relaxed) >= cap {
return ignore::WalkState::Quit;
}
}
let Ok(entry) = entry else {
return ignore::WalkState::Continue;
};
if !entry.file_type().is_some_and(|ft| ft.is_file()) {
return ignore::WalkState::Continue;
}
let path = entry.path();
let (file_len, mtime) = match std::fs::metadata(path) {
Ok(meta) => (
meta.len(),
meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH),
),
Err(_) => return ignore::WalkState::Continue,
};
if file_len > 500_000 {
return ignore::WalkState::Continue;
}
if crate::search::io::is_minified_filename(path) {
return ignore::WalkState::Continue;
}
let Some(bytes) = crate::search::read_file_bytes(path, file_len) else {
return ignore::WalkState::Continue;
};
let any_match = if let Some(ac) = ac {
ac.is_match(&*bytes)
} else {
sorted_targets
.iter()
.any(|t| memchr::memmem::find(&bytes, t.as_bytes()).is_some())
};
if !any_match {
return ignore::WalkState::Continue;
}
if file_len >= crate::search::io::MINIFIED_CHECK_THRESHOLD
&& crate::search::io::looks_minified(&bytes)
{
return ignore::WalkState::Continue;
}
let Ok(content) = std::str::from_utf8(&bytes) else {
return ignore::WalkState::Continue;
};
if !targets
.iter()
.any(|t| bloom.contains(path, mtime, content, t))
{
return ignore::WalkState::Continue;
}
let file_type = detect_file_type(path);
let FileType::Code(lang) = file_type else {
return ignore::WalkState::Continue;
};
let Some(ts_lang) = outline_language(lang) else {
return ignore::WalkState::Continue;
};
let file_callers =
find_callers_treesitter_batch(path, targets, &ts_lang, content, lang, mtime, cache);
if !file_callers.is_empty() {
found_count.fetch_add(file_callers.len(), Ordering::Relaxed);
let mut all = matches
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
all.extend(file_callers);
}
ignore::WalkState::Continue
})
});
Ok(matches
.into_inner()
.unwrap_or_else(std::sync::PoisonError::into_inner))
}
fn find_callers_treesitter_batch(
path: &Path,
targets: &HashSet<String>,
ts_lang: &tree_sitter::Language,
content: &str,
lang: crate::types::Lang,
mtime: std::time::SystemTime,
cache: Option<&crate::cache::OutlineCache>,
) -> Vec<(String, CallerMatch)> {
let Some(query_str) = crate::search::callees::callee_query_str(lang) else {
return Vec::new();
};
let tree = if let Some(c) = cache {
let Some(tree) = c.get_or_parse(path, mtime, content, ts_lang) else {
return Vec::new();
};
tree
} else {
let mut parser = tree_sitter::Parser::new();
if parser.set_language(ts_lang).is_err() {
return Vec::new();
}
let Some(tree) = parser.parse(content, None) else {
return Vec::new();
};
tree
};
let content_bytes = content.as_bytes();
let lines: Vec<&str> = content.lines().collect();
let shared_content: Arc<String> = Arc::new(content.to_string());
let Some(callers) = crate::search::callees::with_callee_query(ts_lang, query_str, |query| {
let Some(callee_idx) = query.capture_index_for_name("callee") else {
return Vec::new();
};
let mut cursor = tree_sitter::QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), content_bytes);
let mut callers = Vec::new();
while let Some(m) = matches.next() {
for cap in m.captures {
if cap.index != callee_idx {
continue;
}
let Ok(text) = cap.node.utf8_text(content_bytes) else {
continue;
};
if !targets.contains(text) {
continue;
}
let matched_target = text.to_string();
let line = cap.node.start_position().row as u32 + 1;
let call_node = cap.node.parent().unwrap_or(cap.node);
let same_line = call_node.start_position().row == call_node.end_position().row;
let call_text: String = if same_line {
let row = call_node.start_position().row;
if row < lines.len() {
lines[row].trim().to_string()
} else {
matched_target.clone()
}
} else {
matched_target.clone()
};
let (calling_function, caller_range) =
find_enclosing_function(cap.node, &lines, lang);
callers.push((
matched_target,
CallerMatch {
path: path.to_path_buf(),
line,
calling_function,
call_text,
caller_range,
content: Arc::clone(&shared_content),
},
));
}
}
callers
}) else {
return Vec::new();
};
callers
}
const TYPE_KINDS: &[&str] = &[
"class_declaration",
"class_definition",
"struct_item",
"impl_item",
"interface_declaration",
"trait_item",
"trait_declaration",
"type_declaration",
"enum_item",
"enum_declaration",
"module",
"mod_item",
"namespace_definition",
];
fn find_enclosing_function(
node: tree_sitter::Node,
lines: &[&str],
lang: crate::types::Lang,
) -> (String, Option<(u32, u32)>) {
let mut current = Some(node);
while let Some(n) = current {
let kind = n.kind();
let def_name = if DEFINITION_KINDS.contains(&kind) {
extract_definition_name(n, lines)
} else if lang == crate::types::Lang::Elixir
&& crate::lang::treesitter::is_elixir_definition(n, lines)
{
crate::lang::treesitter::extract_elixir_definition_name(n, lines)
} else {
None
};
if let Some(name) = def_name {
let range = Some((
n.start_position().row as u32 + 1,
n.end_position().row as u32 + 1,
));
let mut parent = n.parent();
while let Some(p) = parent {
if TYPE_KINDS.contains(&p.kind()) {
if let Some(type_name) = extract_definition_name(p, lines) {
return (format!("{type_name}.{name}"), range);
}
}
if lang == crate::types::Lang::Elixir
&& crate::lang::treesitter::is_elixir_definition(p, lines)
{
if let Some(type_name) =
crate::lang::treesitter::extract_elixir_definition_name(p, lines)
{
return (format!("{type_name}.{name}"), range);
}
}
parent = p.parent();
}
return (name, range);
}
current = n.parent();
}
("<top-level>".to_string(), None)
}
pub fn search_callers_expanded(
target: &str,
scope: &Path,
cache: &OutlineCache,
_session: &Session,
bloom: &crate::index::bloom::BloomFilterCache,
expand: usize,
context: Option<&Path>,
limit: Option<usize>,
offset: usize,
glob: Option<&str>,
) -> Result<String, SrcwalkError> {
let max_matches = limit.unwrap_or(usize::MAX);
let callers = find_callers(target, scope, bloom, glob, Some(cache))?;
if callers.is_empty() {
return Ok(format!(
"# Callers of \"{}\" in {} — no call sites found\n\n\
Tip: srcwalk detects only direct, by-name call sites. The symbol may still be invoked via:\n\
- Rust trait objects (`dyn Trait`) or generic bounds\n\
- Go interface dispatch or function values stored in structs\n\
- Java/Kotlin interface or abstract methods, reflection\n\
- TypeScript/JS class hierarchies, callbacks, or dynamic property access\n\
- Python duck typing, `getattr`, decorators\n\n\
Try `srcwalk(\"{}\")` (symbol search) to find the declaring interface/trait, \
then run `callers` on that name, or search for implementors.",
target,
scope.display(),
target,
));
}
let mut sorted_callers = callers;
rank_callers(&mut sorted_callers, scope, context);
let total = sorted_callers.len();
let all_caller_names: HashSet<String> = sorted_callers
.iter()
.filter(|c| c.calling_function != "<top-level>")
.map(|c| c.calling_function.clone())
.collect();
let effective_offset = offset.min(total);
if effective_offset > 0 {
sorted_callers.drain(..effective_offset);
}
sorted_callers.truncate(max_matches);
let shown = sorted_callers.len();
let mut output = format!(
"# Callers of \"{}\" in {} — {} call site{}\n",
target,
scope.display(),
total,
if total == 1 { "" } else { "s" }
);
for (i, caller) in sorted_callers.iter().enumerate() {
let _ = write!(
output,
"\n## {}:{} [caller: {}]\n",
rel_nonempty(&caller.path, scope),
caller.line,
caller.calling_function
);
let _ = writeln!(output, "→ {}", caller.call_text);
if i < expand {
if let Some((start, end)) = caller.caller_range {
let lines: Vec<&str> = caller.content.lines().collect();
let window_start = caller.line.saturating_sub(2).max(start);
let window_end = (caller.line + 2).min(end);
let start_idx = (window_start as usize).saturating_sub(1);
let end_idx = (window_end as usize).min(lines.len());
output.push('\n');
output.push_str("```\n");
for (idx, line) in lines[start_idx..end_idx].iter().enumerate() {
let line_num = start_idx + idx + 1;
let prefix = if line_num == caller.line as usize {
"► "
} else {
" "
};
let _ = writeln!(output, "{prefix}{line_num:4} │ {line}");
}
output.push_str("```\n");
}
}
}
if total > effective_offset + shown {
let omitted = total - effective_offset - shown;
let next_offset = effective_offset + shown;
let _ = write!(
output,
"\n... and {omitted} more call sites. Next page: --offset {next_offset}."
);
} else if effective_offset > 0 {
let _ = write!(output, "\n(end of results, offset={effective_offset})");
}
if !all_caller_names.is_empty() && all_caller_names.len() <= IMPACT_FANOUT_THRESHOLD {
if let Ok(hop2) = find_callers_batch(
&all_caller_names,
scope,
bloom,
glob,
Some(cache),
Some(BATCH_EARLY_QUIT),
) {
let hop1_locations: HashSet<(PathBuf, u32)> = sorted_callers
.iter()
.map(|c| (c.path.clone(), c.line))
.collect();
let hop2_filtered: Vec<_> = hop2
.into_iter()
.filter(|(_, m)| !hop1_locations.contains(&(m.path.clone(), m.line)))
.collect();
if !hop2_filtered.is_empty() {
output.push_str("\n── impact (2nd hop) ──\n");
let mut seen: HashSet<(String, PathBuf)> = HashSet::new();
let mut count = 0;
for (via, m) in &hop2_filtered {
let key = (m.calling_function.clone(), m.path.clone());
if !seen.insert(key) {
continue;
}
if count >= IMPACT_MAX_RESULTS {
break;
}
let rel_path = rel_nonempty(&m.path, scope);
let _ = writeln!(
output,
" {:<20} {}:{} \u{2192} {}",
m.calling_function, rel_path, m.line, via
);
count += 1;
}
let unique_total = hop2_filtered
.iter()
.map(|(_, m)| (&m.calling_function, &m.path))
.collect::<HashSet<_>>()
.len();
if unique_total > IMPACT_MAX_RESULTS {
let _ = writeln!(
output,
" ... and {} more",
unique_total - IMPACT_MAX_RESULTS
);
}
let _ = writeln!(
output,
"\n{} functions affected across 2 hops.",
sorted_callers.len() + count
);
}
}
}
let tokens = crate::types::estimate_tokens(output.len() as u64);
let token_str = if tokens >= 1000 {
format!("~{}.{}k", tokens / 1000, (tokens % 1000) / 100)
} else {
format!("~{tokens}")
};
let _ = write!(output, "\n\n({token_str} tokens)");
Ok(output)
}
fn rank_callers(callers: &mut [CallerMatch], scope: &Path, context: Option<&Path>) {
callers.sort_by(|a, b| {
if let Some(ctx) = context {
match (a.path == ctx, b.path == ctx) {
(true, false) => return std::cmp::Ordering::Less,
(false, true) => return std::cmp::Ordering::Greater,
_ => {}
}
}
let a_rel = a.path.strip_prefix(scope).unwrap_or(&a.path);
let b_rel = b.path.strip_prefix(scope).unwrap_or(&b.path);
a_rel
.components()
.count()
.cmp(&b_rel.components().count())
.then_with(|| a.path.cmp(&b.path))
.then_with(|| a.line.cmp(&b.line))
});
}