use std::collections::HashSet;
use std::fmt::Write as _;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use streaming_iterator::StreamingIterator;
use crate::lang::treesitter::{extract_definition_name, DEFINITION_KINDS};
use crate::cache::OutlineCache;
use crate::error::SrcwalkError;
use crate::format::rel_nonempty;
use crate::lang::detect_file_type;
use crate::lang::outline::outline_language;
use crate::session::Session;
use crate::types::FileType;
/// Largest number of distinct direct callers for which the second-hop
/// impact search in `search_callers_expanded` is still attempted.
const IMPACT_FANOUT_THRESHOLD: usize = 10;
/// Maximum number of distinct second-hop (function, file) rows printed.
const IMPACT_MAX_RESULTS: usize = 15;
/// Match count after which the batched second-hop walk stops visiting files.
const BATCH_EARLY_QUIT: usize = 50;
/// Caller name reported for call sites that are not inside any recognized
/// definition (the same text `find_enclosing_function` falls back to).
pub(super) const TOP_LEVEL: &str = "<top-level>";
/// One call site of a searched symbol, produced by the tree-sitter callee query.
#[derive(Debug)]
pub struct CallerMatch {
    /// File that contains the call.
    pub path: PathBuf,
    /// 1-based line of the callee identifier.
    pub line: u32,
    /// Name of the enclosing definition, optionally qualified as `Type.name`;
    /// `<top-level>` when the call is not inside a recognized definition.
    pub calling_function: String,
    /// Trimmed source line when the call fits on one line, otherwise the callee text.
    pub call_text: String,
    /// 1-based `(start, end)` lines of the enclosing definition, if there is one.
    pub caller_range: Option<(u32, u32)>,
    /// Receiver / qualifier text (e.g. `client` in `client.start()`), if detected.
    pub receiver: Option<String>,
    /// Number of arguments in the call's argument list (capped at 255), if detected.
    pub arg_count: Option<u8>,
    /// Full text of the source file, shared by all matches found in that file.
    pub content: Arc<String>,
}
/// Find every direct, by-name call site of `target` under `scope`.
///
/// Files are walked in parallel (restricted by `glob` when given) and cheaply
/// pre-filtered before any parsing. A file is skipped if any of these hold:
/// - it is larger than 500,000 bytes or has a minified-looking name;
/// - it does not contain the raw bytes of `target`;
/// - it looks minified;
/// - it is not UTF-8;
/// - the bloom index rules it out;
/// - it is not a code file with an outline language.
///
/// Surviving files are parsed with tree-sitter (through `cache` when
/// supplied). Every callee whose text equals `target` becomes a [`CallerMatch`].
///
/// Result order across files depends on the parallel walk and is not stable.
///
/// # Errors
/// Returns an error only if the directory walker cannot be built for
/// `scope`/`glob`; per-file I/O and parse failures are skipped.
pub fn find_callers(
    target: &str,
    scope: &Path,
    bloom: &crate::index::bloom::BloomFilterCache,
    glob: Option<&str>,
    cache: Option<&crate::cache::OutlineCache>,
) -> Result<Vec<CallerMatch>, SrcwalkError> {
    let matches: Mutex<Vec<CallerMatch>> = Mutex::new(Vec::new());
    let needle = target.as_bytes();
    let walker = crate::search::walker(scope, glob)?;
    walker.run(|| {
        let matches = &matches;
        Box::new(move |entry| {
            let Ok(entry) = entry else {
                return ignore::WalkState::Continue;
            };
            if !entry.file_type().is_some_and(|ft| ft.is_file()) {
                return ignore::WalkState::Continue;
            }
            let path = entry.path();
            let (file_len, mtime) = match std::fs::metadata(path) {
                Ok(meta) => (
                    meta.len(),
                    meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH),
                ),
                Err(_) => return ignore::WalkState::Continue,
            };
            if file_len > 500_000 {
                return ignore::WalkState::Continue;
            }
            if crate::search::io::is_minified_filename(path) {
                return ignore::WalkState::Continue;
            }
            let Some(bytes) = crate::search::read_file_bytes(path, file_len) else {
                return ignore::WalkState::Continue;
            };
            // Cheap byte-level prefilter before any UTF-8 validation or parsing.
            if memchr::memmem::find(&bytes, needle).is_none() {
                return ignore::WalkState::Continue;
            }
            if file_len >= crate::search::io::MINIFIED_CHECK_THRESHOLD
                && crate::search::io::looks_minified(&bytes)
            {
                return ignore::WalkState::Continue;
            }
            let Ok(content) = std::str::from_utf8(&bytes) else {
                return ignore::WalkState::Continue;
            };
            if !bloom.contains(path, mtime, content, target) {
                return ignore::WalkState::Continue;
            }
            let FileType::Code(lang) = detect_file_type(path) else {
                return ignore::WalkState::Continue;
            };
            let Some(ts_lang) = outline_language(lang) else {
                return ignore::WalkState::Continue;
            };
            let file_callers =
                find_callers_treesitter(path, target, &ts_lang, content, lang, mtime, cache);
            if !file_callers.is_empty() {
                // A poisoned lock only means another worker panicked; keep its results.
                matches
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .extend(file_callers);
            }
            ignore::WalkState::Continue
        })
    });
    Ok(matches
        .into_inner()
        .unwrap_or_else(std::sync::PoisonError::into_inner))
}
/// Parse one file and collect every call whose callee text is exactly `target`.
///
/// Uses the language's callee query from `callee_query_str`. The tree comes
/// from `cache` when one is supplied, otherwise from a fresh parser. Returns
/// an empty vector in any of these cases:
/// - the language has no callee query;
/// - parsing fails;
/// - `with_callee_query` yields nothing;
/// - the query has no `callee` capture.
fn find_callers_treesitter(
    path: &Path,
    target: &str,
    ts_lang: &tree_sitter::Language,
    content: &str,
    lang: crate::types::Lang,
    mtime: std::time::SystemTime,
    cache: Option<&crate::cache::OutlineCache>,
) -> Vec<CallerMatch> {
    let Some(query_src) = crate::search::callees::callee_query_str(lang) else {
        return Vec::new();
    };
    let parsed = match cache {
        Some(outline_cache) => outline_cache.get_or_parse(path, mtime, content, ts_lang),
        None => {
            let mut parser = tree_sitter::Parser::new();
            if parser.set_language(ts_lang).is_err() {
                None
            } else {
                parser.parse(content, None)
            }
        }
    };
    let Some(tree) = parsed else {
        return Vec::new();
    };

    let source = content.as_bytes();
    let source_lines: Vec<&str> = content.lines().collect();
    // One copy of the file text, shared by every match from this file.
    let file_text: Arc<String> = Arc::new(content.to_string());

    crate::search::callees::with_callee_query(ts_lang, query_src, |query| {
        let Some(callee_capture) = query.capture_index_for_name("callee") else {
            return Vec::new();
        };
        let mut query_cursor = tree_sitter::QueryCursor::new();
        let mut query_matches = query_cursor.matches(query, tree.root_node(), source);
        let mut found = Vec::new();
        while let Some(query_match) = query_matches.next() {
            for capture in query_match
                .captures
                .iter()
                .filter(|c| c.index == callee_capture)
            {
                let callee = capture.node;
                let Ok(callee_text) = callee.utf8_text(source) else {
                    continue;
                };
                if callee_text != target {
                    continue;
                }
                let call_node = callee.parent().unwrap_or(callee);
                let start_row = call_node.start_position().row;
                // Show the whole source line only when the call fits on it.
                let call_text = if start_row == call_node.end_position().row {
                    source_lines
                        .get(start_row)
                        .map_or_else(|| callee_text.to_string(), |row| row.trim().to_string())
                } else {
                    callee_text.to_string()
                };
                let (calling_function, caller_range) =
                    find_enclosing_function(callee, &source_lines, lang);
                found.push(CallerMatch {
                    path: path.to_path_buf(),
                    line: callee.start_position().row as u32 + 1,
                    calling_function,
                    call_text,
                    caller_range,
                    receiver: extract_receiver(callee, source),
                    arg_count: extract_arg_count(call_node),
                    content: Arc::clone(&file_text),
                });
            }
        }
        found
    })
    .unwrap_or_default()
}
pub(crate) fn find_callers_batch(
targets: &HashSet<String>,
scope: &Path,
bloom: &crate::index::bloom::BloomFilterCache,
glob: Option<&str>,
cache: Option<&crate::cache::OutlineCache>,
early_quit: Option<usize>,
) -> Result<Vec<(String, CallerMatch)>, SrcwalkError> {
let matches: Mutex<Vec<(String, CallerMatch)>> = Mutex::new(Vec::new());
let found_count = AtomicUsize::new(0);
let target_vec: Vec<&str> = targets.iter().map(String::as_str).collect();
let ac = if target_vec.len() >= 3 {
aho_corasick::AhoCorasick::new(&target_vec).ok()
} else {
None
};
let mut sorted_targets: Vec<&str> = target_vec.clone();
sorted_targets.sort_by_key(|t| std::cmp::Reverse(t.len()));
let walker = crate::search::walker(scope, glob)?;
walker.run(|| {
let matches = &matches;
let found_count = &found_count;
let ac = ac.as_ref();
let sorted_targets = &sorted_targets;
Box::new(move |entry| {
if let Some(cap) = early_quit {
if found_count.load(Ordering::Relaxed) >= cap {
return ignore::WalkState::Quit;
}
}
let Ok(entry) = entry else {
return ignore::WalkState::Continue;
};
if !entry.file_type().is_some_and(|ft| ft.is_file()) {
return ignore::WalkState::Continue;
}
let path = entry.path();
let (file_len, mtime) = match std::fs::metadata(path) {
Ok(meta) => (
meta.len(),
meta.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH),
),
Err(_) => return ignore::WalkState::Continue,
};
if file_len > 500_000 {
return ignore::WalkState::Continue;
}
if crate::search::io::is_minified_filename(path) {
return ignore::WalkState::Continue;
}
let Some(bytes) = crate::search::read_file_bytes(path, file_len) else {
return ignore::WalkState::Continue;
};
let any_match = if let Some(ac) = ac {
ac.is_match(&*bytes)
} else {
sorted_targets
.iter()
.any(|t| memchr::memmem::find(&bytes, t.as_bytes()).is_some())
};
if !any_match {
return ignore::WalkState::Continue;
}
if file_len >= crate::search::io::MINIFIED_CHECK_THRESHOLD
&& crate::search::io::looks_minified(&bytes)
{
return ignore::WalkState::Continue;
}
let Ok(content) = std::str::from_utf8(&bytes) else {
return ignore::WalkState::Continue;
};
if !targets
.iter()
.any(|t| bloom.contains(path, mtime, content, t))
{
return ignore::WalkState::Continue;
}
let file_type = detect_file_type(path);
let FileType::Code(lang) = file_type else {
return ignore::WalkState::Continue;
};
let Some(ts_lang) = outline_language(lang) else {
return ignore::WalkState::Continue;
};
let file_callers =
find_callers_treesitter_batch(path, targets, &ts_lang, content, lang, mtime, cache);
if !file_callers.is_empty() {
found_count.fetch_add(file_callers.len(), Ordering::Relaxed);
let mut all = matches
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
all.extend(file_callers);
}
ignore::WalkState::Continue
})
});
Ok(matches
.into_inner()
.unwrap_or_else(std::sync::PoisonError::into_inner))
}
/// Parse one file and collect calls whose callee text is any name in `targets`.
///
/// Each result is paired with the target name it matched. The tree comes from
/// `cache` when one is supplied, otherwise from a fresh parser. Returns an
/// empty vector if the language has no callee query, parsing fails,
/// `with_callee_query` yields nothing, or the query lacks a `callee` capture.
fn find_callers_treesitter_batch(
    path: &Path,
    targets: &HashSet<String>,
    ts_lang: &tree_sitter::Language,
    content: &str,
    lang: crate::types::Lang,
    mtime: std::time::SystemTime,
    cache: Option<&crate::cache::OutlineCache>,
) -> Vec<(String, CallerMatch)> {
    let Some(query_src) = crate::search::callees::callee_query_str(lang) else {
        return Vec::new();
    };
    let parsed = match cache {
        Some(outline_cache) => outline_cache.get_or_parse(path, mtime, content, ts_lang),
        None => {
            let mut parser = tree_sitter::Parser::new();
            match parser.set_language(ts_lang) {
                Ok(_) => parser.parse(content, None),
                Err(_) => None,
            }
        }
    };
    let Some(tree) = parsed else {
        return Vec::new();
    };

    let source = content.as_bytes();
    let source_lines: Vec<&str> = content.lines().collect();
    // One copy of the file text, shared by every match from this file.
    let file_text: Arc<String> = Arc::new(content.to_string());

    crate::search::callees::with_callee_query(ts_lang, query_src, |query| {
        let Some(callee_capture) = query.capture_index_for_name("callee") else {
            return Vec::new();
        };
        let mut query_cursor = tree_sitter::QueryCursor::new();
        let mut query_matches = query_cursor.matches(query, tree.root_node(), source);
        let mut found = Vec::new();
        while let Some(query_match) = query_matches.next() {
            for capture in query_match
                .captures
                .iter()
                .filter(|c| c.index == callee_capture)
            {
                let callee = capture.node;
                let Ok(callee_text) = callee.utf8_text(source) else {
                    continue;
                };
                if !targets.contains(callee_text) {
                    continue;
                }
                let call_node = callee.parent().unwrap_or(callee);
                let start_row = call_node.start_position().row;
                let single_line = start_row == call_node.end_position().row;
                let call_text = match source_lines.get(start_row) {
                    Some(row_text) if single_line => row_text.trim().to_string(),
                    _ => callee_text.to_string(),
                };
                let (calling_function, caller_range) =
                    find_enclosing_function(callee, &source_lines, lang);
                let site = CallerMatch {
                    path: path.to_path_buf(),
                    line: callee.start_position().row as u32 + 1,
                    calling_function,
                    call_text,
                    caller_range,
                    receiver: extract_receiver(callee, source),
                    arg_count: extract_arg_count(call_node),
                    content: Arc::clone(&file_text),
                };
                found.push((callee_text.to_string(), site));
            }
        }
        found
    })
    .unwrap_or_default()
}
/// Text of the receiver / qualifier of a call, e.g. `client` in `client.start()`.
///
/// Looks at the parent of `callee_node`. Depending on the grammar's node kind,
/// it picks:
/// - the `object` / `receiver` / `expression` field, or
/// - the first named child that is not the callee.
///
/// Receiver text longer than 40 bytes is abbreviated (see
/// `truncate_receiver_text`). Prefixes of scoped / qualified identifiers are
/// returned in full. Returns `None` for unrecognised parent kinds, missing
/// children, or non-UTF-8 text.
fn extract_receiver(callee_node: tree_sitter::Node, source: &[u8]) -> Option<String> {
    let parent = callee_node.parent()?;
    match parent.kind() {
        "field_expression"
        | "member_expression"
        | "selector_expression"
        | "attribute"
        | "member_access_expression"
        | "scoped_call_expression"
        | "member_call_expression" => {
            let obj = parent
                .child_by_field_name("object")
                .or_else(|| parent.child_by_field_name("receiver"))
                .or_else(|| parent.child_by_field_name("expression"))
                .or_else(|| {
                    (0..parent.named_child_count())
                        .filter_map(|i| parent.named_child(i))
                        .find(|c| c.id() != callee_node.id())
                })?;
            Some(truncate_receiver_text(obj.utf8_text(source).ok()?))
        }
        "method_invocation" => {
            let obj = parent.child_by_field_name("object")?;
            Some(truncate_receiver_text(obj.utf8_text(source).ok()?))
        }
        "scoped_identifier" | "qualified_identifier" => {
            let mut cursor = parent.walk();
            let first = parent
                .named_children(&mut cursor)
                .find(|c| c.id() != callee_node.id())?;
            Some(first.utf8_text(source).ok()?.to_string())
        }
        "call" => {
            let obj = parent.child_by_field_name("receiver")?;
            Some(truncate_receiver_text(obj.utf8_text(source).ok()?))
        }
        "navigation_expression" => (0..parent.named_child_count())
            .filter_map(|i| parent.named_child(i))
            .find(|c| c.id() != callee_node.id())
            .and_then(|obj| obj.utf8_text(source).ok())
            .map(truncate_receiver_text),
        "navigation_suffix" => {
            let nav = parent.parent()?;
            if nav.kind() != "navigation_expression" {
                return None;
            }
            (0..nav.named_child_count())
                .filter_map(|i| nav.named_child(i))
                .find(|c| c.kind() != "navigation_suffix")
                .and_then(|obj| obj.utf8_text(source).ok())
                .map(truncate_receiver_text)
        }
        _ => None,
    }
}

/// Abbreviate receiver text longer than 40 bytes to at most 37 bytes plus `…`.
///
/// The cut is moved back to the nearest UTF-8 character boundary. Slicing a
/// `str` at a fixed byte offset panics when that offset falls inside a
/// multi-byte character, e.g. in a non-ASCII string literal receiver.
fn truncate_receiver_text(text: &str) -> String {
    const MAX_BYTES: usize = 40;
    const KEEP_BYTES: usize = 37;
    if text.len() <= MAX_BYTES {
        return text.to_string();
    }
    let mut end = KEEP_BYTES;
    // Terminates: index 0 is always a char boundary.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &text[..end])
}
/// Count the arguments of a call.
///
/// Searches `call_node`, then its parent, for the first named child whose kind
/// is an argument-list kind. The parent is checked as well because `call_node`
/// may be an intermediate node (e.g. a member expression) rather than the call
/// itself. Only named children of the argument list are counted, saturating
/// at 255. Returns `None` when no argument list is found.
fn extract_arg_count(call_node: tree_sitter::Node) -> Option<u8> {
    const ARG_LIST_KINDS: [&str; 6] = [
        "arguments",
        "argument_list",
        "actual_parameters",
        "method_arguments",
        "value_arguments",
        "call_suffix",
    ];
    std::iter::once(call_node)
        .chain(call_node.parent())
        .find_map(|node| {
            let mut cursor = node.walk();
            let args = node
                .named_children(&mut cursor)
                .find(|child| ARG_LIST_KINDS.contains(&child.kind()))?;
            let mut arg_cursor = args.walk();
            let count = args.named_children(&mut arg_cursor).count();
            Some(u8::try_from(count).unwrap_or(u8::MAX))
        })
}
/// Node kinds treated as type-like containers when `find_enclosing_function`
/// qualifies a caller name as `Type.name`.
const TYPE_KINDS: &[&str] = &[
    // classes, structs and impl blocks
    "class_declaration",
    "class_definition",
    "struct_item",
    "impl_item",
    // interfaces and traits
    "interface_declaration",
    "trait_item",
    "trait_declaration",
    // other type declarations
    "type_declaration",
    "enum_item",
    "enum_declaration",
    // modules and namespaces
    "module",
    "mod_item",
    "namespace_definition",
];
fn find_enclosing_function(
node: tree_sitter::Node,
lines: &[&str],
lang: crate::types::Lang,
) -> (String, Option<(u32, u32)>) {
let mut current = Some(node);
while let Some(n) = current {
let kind = n.kind();
let def_name = if DEFINITION_KINDS.contains(&kind) {
extract_definition_name(n, lines)
} else if lang == crate::types::Lang::Elixir
&& crate::lang::treesitter::is_elixir_definition(n, lines)
{
crate::lang::treesitter::extract_elixir_definition_name(n, lines)
} else {
None
};
if let Some(name) = def_name {
let range = Some((
n.start_position().row as u32 + 1,
n.end_position().row as u32 + 1,
));
let mut parent = n.parent();
while let Some(p) = parent {
if TYPE_KINDS.contains(&p.kind()) {
if let Some(type_name) = extract_definition_name(p, lines) {
return (format!("{type_name}.{name}"), range);
}
}
if lang == crate::types::Lang::Elixir
&& crate::lang::treesitter::is_elixir_definition(p, lines)
{
if let Some(type_name) =
crate::lang::treesitter::extract_elixir_definition_name(p, lines)
{
return (format!("{type_name}.{name}"), range);
}
}
parent = p.parent();
}
return (name, range);
}
current = n.parent();
}
("<top-level>".to_string(), None)
}
pub fn search_callers_expanded(
target: &str,
scope: &Path,
cache: &OutlineCache,
_session: &Session,
bloom: &crate::index::bloom::BloomFilterCache,
expand: usize,
context: Option<&Path>,
limit: Option<usize>,
offset: usize,
glob: Option<&str>,
filter: Option<&str>,
count_by: Option<&str>,
) -> Result<String, SrcwalkError> {
let max_matches = limit.unwrap_or(usize::MAX);
let group_limit = limit.unwrap_or(50);
let mut callers = find_callers(target, scope, bloom, glob, Some(cache))?;
let filters = parse_callsite_filters(filter)?;
let unfiltered_total = callers.len();
if !filters.is_empty() {
callers.retain(|caller| filters.iter().all(|f| f.matches(caller, scope)));
}
if callers.is_empty() {
return Ok(format!(
"# Callers of \"{}\" in {} — no call sites found\n\n\
Tip: srcwalk detects only direct, by-name call sites. The symbol may still be invoked via:\n\
- Rust trait objects (`dyn Trait`) or generic bounds\n\
- Go interface dispatch or function values stored in structs\n\
- Java/Kotlin interface or abstract methods, reflection\n\
- TypeScript/JS class hierarchies, callbacks, or dynamic property access\n\
- Python duck typing, `getattr`, decorators\n\n\
Try `srcwalk(\"{}\")` (symbol search) to find the declaring interface/trait, \
then run `callers` on that name, or search for implementors.",
target,
scope.display(),
target,
));
}
if let Some(field) = count_by {
return format_callsite_counts(target, scope, &callers, field, filter, group_limit, offset);
}
let mut sorted_callers = callers;
rank_callers(&mut sorted_callers, scope, context);
let total = sorted_callers.len();
let all_caller_names: HashSet<String> = sorted_callers
.iter()
.filter(|c| c.calling_function != "<top-level>")
.map(|c| c.calling_function.clone())
.collect();
let effective_offset = offset.min(total);
if effective_offset > 0 {
sorted_callers.drain(..effective_offset);
}
sorted_callers.truncate(max_matches);
let shown = sorted_callers.len();
let mut output = format!(
"# Slice: {target} — {total} call site{}\n\n[symbol] {target}\n<- calls\n",
if total == 1 { "" } else { "s" }
);
for (i, caller) in sorted_callers.iter().enumerate() {
let _ = write!(
output,
" [fn] {} {}:{}",
caller.calling_function,
rel_nonempty(&caller.path, scope),
caller.line,
);
if let Some(ref recv) = caller.receiver {
let _ = write!(output, " recv={recv}");
}
if let Some(argc) = caller.arg_count {
let _ = write!(output, " args={argc}");
}
let _ = writeln!(output);
if i < expand {
if let Some((start, end)) = caller.caller_range {
let lines: Vec<&str> = caller.content.lines().collect();
let window_start = caller.line.saturating_sub(2).max(start);
let window_end = (caller.line + 2).min(end);
let start_idx = (window_start as usize).saturating_sub(1);
let end_idx = (window_end as usize).min(lines.len());
output.push_str("\n```\n");
for (idx, line) in lines[start_idx..end_idx].iter().enumerate() {
let line_num = start_idx + idx + 1;
let prefix = if line_num == caller.line as usize {
"► "
} else {
" "
};
let _ = writeln!(output, "{prefix}{line_num:4} │ {line}");
}
output.push_str("```\n");
}
}
}
let mut footer = String::new();
if total > effective_offset + shown {
let omitted = total - effective_offset - shown;
let next_offset = effective_offset + shown;
let page_size = shown.max(1);
let _ = write!(
footer,
"> Tip: {omitted} more call sites available. Continue with --offset {next_offset} --limit {page_size}."
);
} else if effective_offset > 0 {
let _ = write!(
footer,
"> Tip: end of results at offset {effective_offset}."
);
}
if !footer.is_empty() {
footer.push('\n');
}
footer.push_str("> Tip: drill into any call site with `srcwalk <path>:<line>`.");
if sorted_callers
.iter()
.any(|caller| caller.arg_count.is_some() || caller.receiver.is_some())
{
footer.push_str(
"\n> Tip: classify callsites with --count-by args or --filter 'args:N receiver:NAME'.",
);
}
if !filters.is_empty() {
let _ = write!(
footer,
"\n> Tip: filter matched {total}/{unfiltered_total} call sites. Qualifiers: args:N receiver:NAME caller:NAME path:TEXT text:TEXT."
);
}
if !all_caller_names.is_empty() && all_caller_names.len() <= IMPACT_FANOUT_THRESHOLD {
if let Ok(hop2) = find_callers_batch(
&all_caller_names,
scope,
bloom,
glob,
Some(cache),
Some(BATCH_EARLY_QUIT),
) {
let hop1_locations: HashSet<(PathBuf, u32)> = sorted_callers
.iter()
.map(|c| (c.path.clone(), c.line))
.collect();
let hop2_filtered: Vec<_> = hop2
.into_iter()
.filter(|(_, m)| !hop1_locations.contains(&(m.path.clone(), m.line)))
.collect();
if !hop2_filtered.is_empty() {
output.push_str("\n── impact (2nd hop) ──\n");
let mut seen: HashSet<(String, PathBuf)> = HashSet::new();
let mut count = 0;
for (via, m) in &hop2_filtered {
let key = (m.calling_function.clone(), m.path.clone());
if !seen.insert(key) {
continue;
}
if count >= IMPACT_MAX_RESULTS {
break;
}
let rel_path = rel_nonempty(&m.path, scope);
let _ = writeln!(
output,
" {:<20} {}:{} \u{2192} {}",
m.calling_function, rel_path, m.line, via
);
count += 1;
}
let unique_total = hop2_filtered
.iter()
.map(|(_, m)| (&m.calling_function, &m.path))
.collect::<HashSet<_>>()
.len();
if unique_total > IMPACT_MAX_RESULTS {
let _ = writeln!(
output,
" ... and {} more",
unique_total - IMPACT_MAX_RESULTS
);
if !footer.is_empty() {
footer.push('\n');
}
footer.push_str(
"> Tip: impact list was capped. Use --callers --depth 2 for the full 2-hop graph.",
);
}
let _ = writeln!(
output,
"\n{} functions affected across 2 hops.",
sorted_callers.len() + count
);
}
}
}
let tokens = crate::types::estimate_tokens(output.len() as u64);
let token_str = if tokens >= 1000 {
format!("~{}.{}k", tokens / 1000, (tokens % 1000) / 100)
} else {
format!("~{tokens}")
};
let _ = write!(output, "\n\n({token_str} tokens)");
if !footer.is_empty() {
let _ = write!(output, "\n\n{footer}");
}
Ok(output)
}
/// Render call sites grouped by one field (`--count-by`).
///
/// `field` is validated and canonicalised by `normalize_count_field`. Groups
/// are ordered by descending count, with ties broken by key. They are paged
/// with `offset` / `limit`; a `limit` of 0 is treated as 1.
///
/// # Errors
/// Returns an error when `field` is not a supported grouping field.
fn format_callsite_counts(
    target: &str,
    scope: &Path,
    callers: &[CallerMatch],
    field: &str,
    filter: Option<&str>,
    limit: usize,
    offset: usize,
) -> Result<String, SrcwalkError> {
    let field = normalize_count_field(field)?;
    let mut counts: std::collections::BTreeMap<String, usize> = std::collections::BTreeMap::new();
    for caller in callers {
        *counts
            .entry(callsite_field_value(caller, scope, field))
            .or_insert(0) += 1;
    }
    let total = callers.len();
    let filter_suffix = filter.map_or_else(String::new, |f| format!(" matching `{f}`"));
    let mut output = format!(
        "# Slice: {target} — {total} call site{} grouped by {field}{}\n\n[symbol] {target}\n<- calls\n",
        if total == 1 { "" } else { "s" },
        filter_suffix,
    );
    let mut rows: Vec<(String, usize)> = counts.into_iter().collect();
    rows.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    let total_groups = rows.len();
    let effective_offset = offset.min(total_groups);
    let page_size = limit.max(1);
    for (key, count) in rows.into_iter().skip(effective_offset).take(page_size) {
        let _ = writeln!(output, " [group] {field}={key} count={count}");
    }
    // `saturating_add`: a huge `--limit` combined with an offset must not overflow.
    let shown_end = effective_offset.saturating_add(page_size).min(total_groups);
    let mut footer = String::from(
        "> Tip: narrow with --filter 'args:N receiver:NAME caller:NAME path:TEXT text:TEXT'; group with --count-by args|caller|receiver|file.",
    );
    if total_groups > shown_end {
        let omitted = total_groups - shown_end;
        let _ = write!(
            footer,
            "\n> Tip: {omitted} more groups available. Continue with --offset {shown_end} --limit {page_size}."
        );
    } else if effective_offset > 0 {
        let _ = write!(
            footer,
            "\n> Tip: end of groups at offset {effective_offset}."
        );
    }
    let _ = write!(output, "\n{footer}");
    Ok(output)
}
/// Map a user-supplied `--count-by` field to its canonical name.
///
/// Accepts `args`, `caller`, `receiver` (alias `recv`), `path` and `file`.
///
/// # Errors
/// Returns `SrcwalkError::InvalidQuery` for any other field.
fn normalize_count_field(field: &str) -> Result<&'static str, SrcwalkError> {
    // (accepted spelling, canonical name)
    const FIELDS: [(&str, &str); 6] = [
        ("args", "args"),
        ("caller", "caller"),
        ("receiver", "receiver"),
        ("recv", "receiver"),
        ("path", "path"),
        ("file", "file"),
    ];
    FIELDS
        .iter()
        .find(|(alias, _)| *alias == field)
        .map(|&(_, canonical)| canonical)
        .ok_or_else(|| SrcwalkError::InvalidQuery {
            query: field.to_string(),
            reason: "unsupported count field; use args, caller, receiver, path, or file"
                .to_string(),
        })
}
/// Display value of `field` for one call site, used as a `--count-by` group key.
///
/// Missing data renders as:
/// - `?` for `args`;
/// - `<none>` for `receiver`;
/// - `<unknown>` for a non-UTF-8 file name or an unrecognised field.
fn callsite_field_value(caller: &CallerMatch, scope: &Path, field: &str) -> String {
    match field {
        "args" => match caller.arg_count {
            Some(argc) => argc.to_string(),
            None => String::from("?"),
        },
        "caller" => caller.calling_function.clone(),
        "receiver" => caller.receiver.as_deref().unwrap_or("<none>").to_owned(),
        "path" => rel_nonempty(&caller.path, scope),
        "file" => caller
            .path
            .file_name()
            .and_then(std::ffi::OsStr::to_str)
            .unwrap_or("<unknown>")
            .to_owned(),
        _ => String::from("<unknown>"),
    }
}
/// One `field:value` qualifier parsed from `--filter`.
#[derive(Debug, PartialEq, Eq)]
struct CallsiteFilter {
    /// Lower-cased qualifier name, e.g. `args`, `receiver`, `path`.
    field: String,
    /// Text after the first `:` of the qualifier, compared against the call site.
    value: String,
}
/// Parse a `--filter` string of whitespace-separated `field:value` qualifiers.
///
/// Field names are lower-cased; the value is everything after the first `:`.
/// Accepted fields: `args`, `receiver` (alias `recv`), `caller`, `path`,
/// `file` and `text`. `None` yields no filters.
///
/// # Errors
/// Returns `SrcwalkError::InvalidQuery` for the first qualifier that has no
/// `:`, has an empty field or value, or names an unsupported field.
fn parse_callsite_filters(filter: Option<&str>) -> Result<Vec<CallsiteFilter>, SrcwalkError> {
    let Some(filter) = filter else {
        return Ok(Vec::new());
    };
    let invalid = |reason: String| SrcwalkError::InvalidQuery {
        query: filter.to_string(),
        reason,
    };
    // `split_whitespace` already strips surrounding blanks, so no trimming is needed.
    filter
        .split_whitespace()
        .map(|part| {
            let (field, value) = part
                .split_once(':')
                .ok_or_else(|| invalid("filters must use field:value qualifiers".to_string()))?;
            let field = field.to_ascii_lowercase();
            if field.is_empty() || value.is_empty() {
                return Err(invalid("filter field and value cannot be empty".to_string()));
            }
            if !matches!(
                field.as_str(),
                "args" | "receiver" | "recv" | "caller" | "path" | "file" | "text"
            ) {
                // List every accepted field, including the `recv` and `file` aliases.
                return Err(invalid(format!(
                    "unsupported filter field `{field}`; use args, receiver (or recv), caller, path, file, or text"
                )));
            }
            Ok(CallsiteFilter {
                field,
                value: value.to_string(),
            })
        })
        .collect()
}
impl CallsiteFilter {
    /// Whether `caller` satisfies this qualifier.
    ///
    /// - `args`: exact argument count; a non-numeric value never matches.
    /// - `receiver` / `recv`: exact receiver text.
    /// - `caller`: exact calling-function name.
    /// - `path` / `file`: substring of the path relative to `scope`.
    /// - `text`: substring of the call text.
    fn matches(&self, caller: &CallerMatch, scope: &Path) -> bool {
        let wanted = self.value.as_str();
        match self.field.as_str() {
            "args" => match (caller.arg_count, wanted.parse::<u8>()) {
                (Some(argc), Ok(expected)) => argc == expected,
                _ => false,
            },
            "receiver" | "recv" => caller.receiver.as_deref() == Some(wanted),
            "caller" => caller.calling_function == wanted,
            "path" | "file" => rel_nonempty(&caller.path, scope).contains(wanted),
            "text" => caller.call_text.contains(wanted),
            _ => false,
        }
    }
}
/// Order call sites in place.
///
/// Sort keys, in priority order:
/// 1. sites in the `context` file first (when a context is given);
/// 2. paths with fewer components relative to `scope` (the full path is used
///    when it lies outside `scope`);
/// 3. path;
/// 4. line.
///
/// The sort is stable.
fn rank_callers(callers: &mut [CallerMatch], scope: &Path, context: Option<&Path>) {
    let depth = |p: &Path| p.strip_prefix(scope).unwrap_or(p).components().count();
    // `false` sorts before `true`, so context-file matches come first.
    let outside_context = |p: &Path| context.is_some_and(|ctx| p != ctx);
    callers.sort_by(|a, b| {
        outside_context(a.path.as_path())
            .cmp(&outside_context(b.path.as_path()))
            .then_with(|| depth(a.path.as_path()).cmp(&depth(b.path.as_path())))
            .then_with(|| a.path.cmp(&b.path))
            .then_with(|| a.line.cmp(&b.line))
    });
}
#[cfg(test)]
mod callsite_filter_tests {
    use super::*;

    /// Fixture: `client.start(1, monitor)` called from `main` in `/repo/src/main.rs`.
    fn sample_match() -> CallerMatch {
        CallerMatch {
            path: PathBuf::from("/repo/src/main.rs"),
            line: 42,
            calling_function: String::from("main"),
            call_text: String::from("client.start(1, monitor)"),
            caller_range: Some((40, 50)),
            receiver: Some(String::from("client")),
            arg_count: Some(2),
            content: Arc::default(),
        }
    }

    #[test]
    fn parse_callsite_filters_accepts_qualifiers() {
        let filters = parse_callsite_filters(Some("args:2 receiver:client caller:main"))
            .expect("valid filters");
        assert_eq!(filters.len(), 3);
        let CallsiteFilter { field, value } = &filters[0];
        assert_eq!(field, "args");
        assert_eq!(value, "2");
    }

    #[test]
    fn parse_callsite_filters_rejects_unknown_fields() {
        let message = parse_callsite_filters(Some("unknown:x"))
            .expect_err("invalid field")
            .to_string();
        assert!(message.contains("unsupported filter field"));
    }

    #[test]
    fn callsite_filters_match_semantic_fields() {
        let caller = sample_match();
        let scope = Path::new("/repo");
        let filters = parse_callsite_filters(Some(
            "args:2 receiver:client caller:main path:src text:start",
        ))
        .expect("valid filters");
        for f in &filters {
            assert!(f.matches(&caller, scope), "filter {f:?} should match");
        }
    }

    #[test]
    fn count_field_values_use_display_facts() {
        let caller = sample_match();
        let scope = Path::new("/repo");
        let expected = [
            ("args", "2"),
            ("caller", "main"),
            ("receiver", "client"),
            ("path", "src/main.rs"),
            ("file", "main.rs"),
        ];
        for (field, want) in expected {
            assert_eq!(callsite_field_value(&caller, scope, field), want, "field {field}");
        }
    }
}