use std::path::{Path, PathBuf};
use ra_ap_syntax::ast::HasAttrs;
use ra_ap_syntax::ast::HasModuleItem;
use ra_ap_syntax::ast::HasName;
use ra_ap_syntax::{AstNode, AstToken, Edition, SyntaxKind, ast};
use crate::error::{Error, io_context};
/// One target selector; each variant corresponds to a CLI flag (`--fn`, `--file`, `--mod`),
/// which is how the selectors are rendered in the "no targets" error.
#[derive(Debug, Clone)]
pub enum TargetSpec {
    /// Select functions whose name contains (or, in exact mode, equals) the pattern.
    Fn(String),
    /// Select every function in files whose path ends with this path.
    File(PathBuf),
    /// Select every function in `<name>.rs` files or files directly inside a `<name>/` directory.
    Mod(String),
}
/// A source file together with the functions selected from it.
#[derive(Debug, Clone)]
pub struct ResolvedTarget {
    /// Path of the file as produced by the directory walk under the source directory.
    pub file: PathBuf,
    /// Selected functions; in a finished `ResolveResult` these are sorted and deduplicated
    /// by their full name.
    pub functions: Vec<crate::naming::QualifiedFunction>,
}
/// Why a discovered function was left uninstrumented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SkipReason {
    /// The function is a `const fn`.
    Const,
    /// The function declares an ABI other than `"Rust"` (a bare `extern` counts as well).
    ExternAbi,
}

impl std::fmt::Display for SkipReason {
    /// Renders the keyword that caused the skip: `const` or `extern`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Const => "const",
            Self::ExternAbi => "extern",
        };
        f.write_str(keyword)
    }
}
/// Verdict on whether a single function can be instrumented.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Classification {
    /// Instrumentation may be injected into the function.
    Instrumentable,
    /// The function must be left untouched, for the attached reason.
    Skip(SkipReason),
}
/// A function that was found but classified as non-instrumentable.
#[derive(Debug, Clone)]
pub struct SkippedFunction {
    /// Minimal qualified name, e.g. `Widget::none`.
    pub name: String,
    /// Why the function was skipped.
    pub reason: SkipReason,
    /// Path label given to `extract_functions`; `resolve_targets` passes the file path
    /// relative to the parent of the source directory.
    pub path: PathBuf,
}
/// Output of `resolve_targets`.
#[derive(Debug)]
pub struct ResolveResult {
    /// Files and functions selected by the specs (every instrumentable function when no
    /// specs were given).
    pub targets: Vec<ResolvedTarget>,
    /// Non-instrumentable functions that the specs matched (all of them when no specs were given).
    pub skipped: Vec<SkippedFunction>,
    /// Every instrumentable function in the source tree, regardless of the specs.
    pub all_functions: Vec<ResolvedTarget>,
}
pub fn resolve_targets(
src_dir: &Path,
specs: &[TargetSpec],
exact: bool,
) -> Result<ResolveResult, Error> {
let rs_files = walk_rs_files(src_dir)?;
let project_root = src_dir.parent().unwrap_or(src_dir);
let rel_path = |file: &Path| -> PathBuf {
file.strip_prefix(project_root)
.unwrap_or(file)
.to_path_buf()
};
let mut results: Vec<ResolvedTarget> = Vec::new();
let mut skipped: Vec<SkippedFunction> = Vec::new();
let mut all_seen_names: Vec<String> = Vec::new();
let mut all_functions: Vec<ResolvedTarget> = Vec::new();
let mut all_skipped: Vec<SkippedFunction> = Vec::new();
for file in &rs_files {
let source = std::fs::read_to_string(file).map_err(|source| Error::RunReadError {
path: file.clone(),
source,
})?;
let (all_fns, file_skipped) = extract_functions(&source, rel_path(file));
if !all_fns.is_empty() {
merge_into(&mut all_functions, file, all_fns);
}
all_skipped.extend(file_skipped);
}
if specs.is_empty() {
results = all_functions.clone();
skipped = all_skipped;
} else {
for spec in specs {
match spec {
TargetSpec::Fn(pattern) => {
for file in &rs_files {
let source = std::fs::read_to_string(file).map_err(|source| {
Error::RunReadError {
path: file.clone(),
source,
}
})?;
let (all_fns, file_skipped) = extract_functions(&source, rel_path(file));
all_seen_names.extend(all_fns.iter().map(|qf| qf.minimal.clone()));
let matched: Vec<crate::naming::QualifiedFunction> = all_fns
.into_iter()
.filter(|qf| {
let name = &qf.minimal;
let bare = name.rsplit("::").next().unwrap_or(name);
if exact {
bare == pattern.as_str() || name == pattern.as_str()
} else {
bare.contains(pattern.as_str())
|| name.contains(pattern.as_str())
}
})
.collect();
if !matched.is_empty() {
merge_into(&mut results, file, matched);
}
let matched_skipped: Vec<SkippedFunction> = file_skipped
.into_iter()
.filter(|s| {
let bare = s.name.rsplit("::").next().unwrap_or(&s.name);
if exact {
bare == pattern.as_str() || s.name == pattern.as_str()
} else {
bare.contains(pattern.as_str())
|| s.name.contains(pattern.as_str())
}
})
.collect();
skipped.extend(matched_skipped);
}
}
TargetSpec::File(file_path) => {
let matching_files: Vec<&PathBuf> =
rs_files.iter().filter(|f| f.ends_with(file_path)).collect();
for file in matching_files {
let source = std::fs::read_to_string(file).map_err(|source| {
Error::RunReadError {
path: file.clone(),
source,
}
})?;
let (all_fns, file_skipped) = extract_functions(&source, rel_path(file));
all_seen_names.extend(all_fns.iter().map(|qf| qf.minimal.clone()));
if !all_fns.is_empty() {
merge_into(&mut results, file, all_fns);
}
skipped.extend(file_skipped);
}
}
TargetSpec::Mod(module_name) => {
for file in &rs_files {
let is_mod_file = file
.parent()
.and_then(|p| p.file_name())
.is_some_and(|dir| dir == module_name.as_str());
let is_named_file = file
.file_stem()
.is_some_and(|stem| stem == module_name.as_str());
if !is_mod_file && !is_named_file {
continue;
}
let source = std::fs::read_to_string(file).map_err(|source| {
Error::RunReadError {
path: file.clone(),
source,
}
})?;
let (all_fns, file_skipped) = extract_functions(&source, rel_path(file));
all_seen_names.extend(all_fns.iter().map(|qf| qf.minimal.clone()));
if !all_fns.is_empty() {
merge_into(&mut results, file, all_fns);
}
skipped.extend(file_skipped);
}
}
}
}
if results.is_empty() {
let desc = specs
.iter()
.map(|s| match s {
TargetSpec::Fn(p) => format!("--fn {p}"),
TargetSpec::File(p) => format!("--file {}", p.display()),
TargetSpec::Mod(m) => format!("--mod {m}"),
})
.collect::<Vec<_>>()
.join(", ");
let hint = if skipped.is_empty() {
build_suggestion_hint(specs, &all_seen_names)
} else {
let reasons = skipped
.iter()
.map(|s| s.reason.to_string())
.collect::<std::collections::BTreeSet<_>>()
.into_iter()
.collect::<Vec<_>>()
.join(", ");
format!(
". All {} matched function(s) were skipped ({}) -- piano cannot instrument these",
skipped.len(),
reasons
)
};
return Err(Error::NoTargetsFound { specs: desc, hint });
}
}
results.sort_by(|a, b| a.file.cmp(&b.file));
for r in &mut results {
r.functions.sort_by(|a, b| a.full.cmp(&b.full));
r.functions.dedup_by(|a, b| a.full == b.full);
}
skipped.sort_by(|a, b| a.name.cmp(&b.name));
skipped.dedup_by(|a, b| a.name == b.name);
all_functions.sort_by(|a, b| a.file.cmp(&b.file));
for r in &mut all_functions {
r.functions.sort_by(|a, b| a.full.cmp(&b.full));
r.functions.dedup_by(|a, b| a.full == b.full);
}
Ok(ResolveResult {
targets: results,
skipped,
all_functions,
})
}
/// Appends `functions` to the entry for `file` in `results`, creating the entry when the file
/// is not present yet. Entries are matched by exact path equality.
fn merge_into(
    results: &mut Vec<ResolvedTarget>,
    file: &Path,
    functions: Vec<crate::naming::QualifiedFunction>,
) {
    match results.iter_mut().find(|target| target.file == file) {
        Some(target) => target.functions.extend(functions),
        None => results.push(ResolvedTarget {
            file: file.to_path_buf(),
            functions,
        }),
    }
}
/// Recursively collects every `.rs` file below `dir`, returned in sorted path order.
fn walk_rs_files(dir: &Path) -> Result<Vec<PathBuf>, Error> {
    let mut rs_files: Vec<PathBuf> = Vec::new();
    walk_rs_files_inner(dir, &mut rs_files).map(|()| {
        rs_files.sort();
        rs_files
    })
}
/// Depth-first walk of `dir` that appends every file with an `rs` extension to `out`.
///
/// NOTE(review): `Path::is_dir` follows symlinks, so a symlinked directory cycle under `dir`
/// would recurse without bound. Confirm whether that can occur for the trees this is run on.
fn walk_rs_files_inner(dir: &Path, out: &mut Vec<PathBuf>) -> Result<(), Error> {
    for entry in std::fs::read_dir(dir).map_err(io_context("read directory", dir))? {
        let path = entry
            .map_err(io_context("read directory entry", dir))?
            .path();
        if path.is_dir() {
            walk_rs_files_inner(&path, out)?;
            continue;
        }
        if path.extension().is_some_and(|ext| ext == "rs") {
            out.push(path);
        }
    }
    Ok(())
}
pub(crate) fn extract_functions(
source: &str,
rel_path: PathBuf,
) -> (Vec<crate::naming::QualifiedFunction>, Vec<SkippedFunction>) {
let parse = ast::SourceFile::parse(source, Edition::Edition2024);
if !parse.errors().is_empty() {
eprintln!(
"warning: parse errors in {} ({} errors, continuing with recovered tree)",
rel_path.display(),
parse.errors().len()
);
}
let file = parse.tree();
let mut collector = FnCollector {
functions: Vec::new(),
skipped: Vec::new(),
path: rel_path,
scope: crate::naming::ScopeState::new(),
};
collector.walk_source_file(&file);
let (expansions, _defs, calls) =
crate::macro_expand::expand_fn_generating_macros(file.syntax());
for exp in &expansions {
let call = &calls[exp.call_idx];
let impl_prefix = {
use ra_ap_syntax::TextSize;
let offset = TextSize::from(call.byte_start as u32);
file.syntax()
.token_at_offset(offset)
.right_biased()
.and_then(|token| {
let mut node = token.parent();
while let Some(n) = node {
if let Some(imp) = ast::Impl::cast(n.clone()) {
let self_ty = imp.self_ty()?;
return Some(crate::naming::render_impl_name(
&self_ty,
imp.trait_().as_ref(),
));
}
if ast::Fn::can_cast(n.kind()) {
break;
}
node = n.parent();
}
None
})
};
for fn_name in &exp.fn_names {
let qualified = if let Some(ref prefix) = impl_prefix {
format!("{prefix}::{fn_name}")
} else {
fn_name.clone()
};
let minimal = collector.scope.render_minimal(&qualified);
let medium = collector.scope.render_medium(&qualified);
let full = collector.scope.render_full(&qualified);
collector
.functions
.push(crate::naming::QualifiedFunction::new(
&minimal, &medium, &full,
));
}
}
(collector.functions, collector.skipped)
}
/// Returns true if any attribute's path `segment()` has the identifier `name`.
/// NOTE(review): `segment()` is presumably the final path segment, so e.g. `#[tokio::test]`
/// would also count as "test". Confirm that is desired.
fn has_attr_cst(attrs: impl Iterator<Item = ast::Attr>, name: &str) -> bool {
    for attr in attrs {
        let ident = attr
            .path()
            .and_then(|path| path.segment())
            .and_then(|segment| segment.name_ref());
        if let Some(ident) = ident {
            if ident.text() == name {
                return true;
            }
        }
    }
    false
}
/// Returns true if any attribute is `#[cfg(test)]`. Surrounding whitespace inside the
/// parentheses is tolerated. Compound predicates such as `cfg(all(test, ..))` are NOT
/// recognised, because the token tree's inner text must be exactly `test`.
fn has_cfg_test_cst(attrs: impl Iterator<Item = ast::Attr>) -> bool {
    attrs
        .filter(|attr| {
            attr.path()
                .and_then(|p| p.segment())
                .and_then(|seg| seg.name_ref())
                .is_some_and(|n| n.text() == "cfg")
        })
        .filter_map(|attr| attr.token_tree())
        .any(|tt| {
            let raw = tt.syntax().text().to_string();
            raw.trim_start_matches('(').trim_end_matches(')').trim() == "test"
        })
}
/// Returns the ABI declared on `func` with surrounding quotes removed:
/// `None` when there is no `extern`, `Some("")` for a bare `extern`, and `Some("C")` for
/// `extern "C"`.
fn extract_cst_abi(func: &ast::Fn) -> Option<String> {
    let abi = func.abi()?;
    let Some(abi_str) = abi.abi_string() else {
        // `extern` with no ABI string.
        return Some(String::new());
    };
    let raw = abi_str.text();
    let unquoted = raw
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(raw);
    Some(unquoted.to_owned())
}
/// Decides whether a function can be instrumented.
///
/// `const fn`s are rejected first. Any explicit ABI other than `"Rust"` is rejected next, and
/// that includes `""`, which is what `extract_cst_abi` reports for a bare `extern`. Everything
/// else is instrumentable.
pub(crate) fn classify(is_const: bool, abi: Option<&str>) -> Classification {
    match (is_const, abi) {
        (true, _) => Classification::Skip(SkipReason::Const),
        (false, Some(declared)) if declared != "Rust" => {
            Classification::Skip(SkipReason::ExternAbi)
        }
        _ => Classification::Instrumentable,
    }
}
/// Classifies a parsed function from its `const` keyword and declared ABI.
fn classify_cst_fn(func: &ast::Fn) -> Classification {
    let is_const = func.const_token().is_some();
    classify(is_const, extract_cst_abi(func).as_deref())
}
/// Syntax-tree visitor that gathers the functions of a single source file.
struct FnCollector {
    /// Instrumentable functions, in visit order.
    functions: Vec<crate::naming::QualifiedFunction>,
    /// Functions classified as non-instrumentable.
    skipped: Vec<SkippedFunction>,
    /// Path label copied into each `SkippedFunction`.
    path: PathBuf,
    /// Current module/fn nesting, used to render qualified names.
    scope: crate::naming::ScopeState,
}
impl FnCollector {
    /// Visits every top-level item of `file`.
    fn walk_source_file(&mut self, file: &ast::SourceFile) {
        for item in file.items() {
            self.walk_item(&item);
        }
    }

    /// Dispatches one item to its visitor. Only modules, free functions, impl blocks and traits
    /// are descended into. Every other item kind, including `extern { .. }` blocks, is ignored.
    fn walk_item(&mut self, item: &ast::Item) {
        match item {
            ast::Item::Module(module) => self.visit_module(module),
            ast::Item::Fn(func) => self.visit_top_level_fn(func),
            ast::Item::Impl(imp) => self.visit_impl(imp),
            ast::Item::Trait(tr) => self.visit_trait(tr),
            _ => {}
        }
    }

    /// Visits each item of an inline module body.
    fn walk_item_list(&mut self, item_list: &ast::ItemList) {
        for item in item_list.items() {
            self.walk_item(&item);
        }
    }

    /// Visits an inline module inside its own `mod` scope. `#[cfg(test)]` modules are skipped
    /// entirely, and a module whose name is missing is scoped as `_`.
    fn visit_module(&mut self, module: &ast::Module) {
        if has_cfg_test_cst(module.attrs()) {
            return;
        }
        let mod_name = module
            .name()
            .map(|n| n.text().to_string())
            .unwrap_or_else(|| "_".to_string());
        self.scope.push_mod(&mod_name);
        if let Some(item_list) = module.item_list() {
            self.walk_item_list(&item_list);
        }
        self.scope.pop();
    }

    /// Records a free function (unless it is `#[test]`), then collects functions nested in
    /// its body. The body is walked even for `#[test]` functions.
    fn visit_top_level_fn(&mut self, func: &ast::Fn) {
        if !has_attr_cst(func.attrs(), "test") {
            let name = Self::name_of(func);
            self.record_function(func, &name);
        }
        self.walk_body_in_fn_scope(func);
    }

    /// Visits every method of an impl block; other associated items are ignored.
    fn visit_impl(&mut self, imp: &ast::Impl) {
        if let Some(assoc_list) = imp.assoc_item_list() {
            for assoc in assoc_list.assoc_items() {
                if let ast::AssocItem::Fn(func) = assoc {
                    self.visit_impl_fn(&func);
                }
            }
        }
    }

    /// Records an impl method under its impl-qualified name (unless it is `#[test]`), then
    /// collects functions nested in its body.
    fn visit_impl_fn(&mut self, func: &ast::Fn) {
        if !has_attr_cst(func.attrs(), "test") {
            let qualified = crate::naming::qualified_name_for_fn(func);
            self.record_function(func, &qualified);
        }
        self.walk_body_in_fn_scope(func);
    }

    /// Visits every method of a trait definition; other associated items are ignored.
    fn visit_trait(&mut self, tr: &ast::Trait) {
        if let Some(assoc_list) = tr.assoc_item_list() {
            for assoc in assoc_list.assoc_items() {
                if let ast::AssocItem::Fn(func) = assoc {
                    self.visit_trait_fn(&func);
                }
            }
        }
    }

    /// Records a trait method only when it has a default body. Required methods have nothing
    /// to instrument.
    fn visit_trait_fn(&mut self, func: &ast::Fn) {
        if func.body().is_none() {
            return;
        }
        let qualified = crate::naming::qualified_name_for_fn(func);
        self.record_function(func, &qualified);
        self.walk_body_in_fn_scope(func);
    }

    /// The function's identifier, or an empty string when the name is missing.
    fn name_of(func: &ast::Fn) -> String {
        func.name()
            .map(|n| n.text().to_string())
            .unwrap_or_default()
    }

    /// Walks `func`'s body (when present) with `func`'s name pushed as the enclosing fn scope.
    fn walk_body_in_fn_scope(&mut self, func: &ast::Fn) {
        let fn_name = Self::name_of(func);
        self.scope.push_fn(&fn_name);
        if let Some(body) = func.body() {
            self.walk_fn_body(&body);
        }
        self.scope.pop();
    }

    /// Records every function defined anywhere inside `body`, at any nesting depth, using the
    /// scope of the enclosing fn.
    ///
    /// Declarations without a body are ignored, since they contain nothing to instrument. These
    /// are the items of a nested `extern { .. }` block and the required methods of a nested
    /// trait. They would otherwise be classified as instrumentable, because a fn inside an
    /// extern block carries no ABI of its own. The same declarations are already ignored at
    /// module level.
    fn walk_fn_body(&mut self, body: &ast::BlockExpr) {
        let nested_fns = body
            .syntax()
            .descendants()
            .filter(|node| node.kind() == SyntaxKind::FN)
            .filter_map(ast::Fn::cast);
        for inner_fn in nested_fns {
            if inner_fn.body().is_none() {
                continue;
            }
            let qualified = crate::naming::qualified_name_for_fn(&inner_fn);
            self.visit_nested_fn(&inner_fn, &qualified);
        }
    }

    /// Records a function nested inside another function's body, unless it is `#[test]`.
    fn visit_nested_fn(&mut self, func: &ast::Fn, qualified: &str) {
        if !has_attr_cst(func.attrs(), "test") {
            self.record_function(func, qualified);
        }
    }

    /// Classifies `func` and files it under `functions` or `skipped`, naming it by rendering
    /// `qualified` against the current scope.
    fn record_function(&mut self, func: &ast::Fn, qualified: &str) {
        let minimal = self.scope.render_minimal(qualified);
        match classify_cst_fn(func) {
            Classification::Skip(reason) => {
                self.skipped.push(SkippedFunction {
                    name: minimal,
                    reason,
                    path: self.path.clone(),
                });
            }
            Classification::Instrumentable => {
                let medium = self.scope.render_medium(qualified);
                let full = self.scope.render_full(qualified);
                self.functions.push(crate::naming::QualifiedFunction::new(
                    &minimal, &medium, &full,
                ));
            }
        }
    }
}
/// Edit distance between `a` and `b`, counting single-`char` insertions, deletions and
/// substitutions.
///
/// Uses the classic single-row dynamic programme in O(|a|·|b|) time and O(|b|) space. The
/// row is sized by `b`'s char count rather than its byte length. If it were sized by bytes,
/// cells past the last char would never be filled in, and a non-ASCII `b` would get a wrong
/// result.
fn levenshtein(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();
    // row[j] = distance between the prefix of `a` processed so far and b_chars[..j].
    let mut row: Vec<usize> = (0..=b_chars.len()).collect();
    for (i, a_ch) in a.chars().enumerate() {
        // `diag` holds the previous row's value at column j (the substitution predecessor).
        let mut diag = i;
        row[0] = i + 1;
        for (j, &b_ch) in b_chars.iter().enumerate() {
            let substitution = diag + usize::from(a_ch != b_ch);
            diag = row[j + 1];
            row[j + 1] = substitution.min(row[j] + 1).min(diag + 1);
        }
    }
    row[b_chars.len()]
}
/// Builds the hint suffix shown when no `--fn` pattern matched any function.
///
/// For each `--fn` pattern, up to five of the closest known names are collected. A name counts
/// as close when its edit distance, measured on either the full name or its last `::` segment,
/// is non-zero and at most a third of the pattern's length in chars. The result is
/// `". Did you mean: a, b?"`, or, when nothing is close, a message with the number of distinct
/// functions seen. An empty string is returned when there are no `--fn` specs.
fn build_suggestion_hint(specs: &[TargetSpec], seen_names: &[String]) -> String {
    let fn_patterns: Vec<&str> = specs
        .iter()
        .filter_map(|s| match s {
            TargetSpec::Fn(p) => Some(p.as_str()),
            TargetSpec::File(_) | TargetSpec::Mod(_) => None,
        })
        .collect();
    if fn_patterns.is_empty() {
        return String::new();
    }
    let mut all_names: Vec<&str> = seen_names.iter().map(String::as_str).collect();
    all_names.sort_unstable();
    all_names.dedup();
    let total_count = all_names.len();
    let mut suggestions: Vec<String> = Vec::new();
    for pattern in &fn_patterns {
        // Measure in chars so the typo budget matches `levenshtein`, which counts chars; a
        // byte length would inflate the budget for non-ASCII identifiers.
        let threshold = pattern.chars().count() / 3;
        let mut scored: Vec<(usize, &str)> = all_names
            .iter()
            .filter_map(|&name| {
                let bare = name.rsplit("::").next().unwrap_or(name);
                let dist = levenshtein(pattern, bare).min(levenshtein(pattern, name));
                (dist > 0 && dist <= threshold).then_some((dist, name))
            })
            .collect();
        // Stable sort: ties keep the alphabetical order of `all_names`.
        scored.sort_by_key(|&(dist, _)| dist);
        suggestions.extend(scored.iter().take(5).map(|&(_, name)| name.to_owned()));
    }
    suggestions.sort();
    suggestions.dedup();
    if suggestions.is_empty() {
        format!(". Found {total_count} functions, none matched. Run without --fn to instrument all")
    } else {
        format!(". Did you mean: {}?", suggestions.join(", "))
    }
}
/// Turns a source path such as `db/query.rs` into a module path such as `db::query`.
///
/// Crate roots and directory modules (`main.rs`, `lib.rs`, `mod.rs`) contribute only their
/// directories. Any path component or file stem that is not valid UTF-8 becomes `_`.
pub fn module_prefix(relative: &Path) -> String {
    let mut segments: Vec<&str> = match relative.parent() {
        Some(dir) => dir
            .components()
            .map(|component| component.as_os_str().to_str().unwrap_or("_"))
            .collect(),
        None => Vec::new(),
    };
    let stem = relative
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("_");
    if !matches!(stem, "main" | "lib" | "mod") {
        segments.push(stem);
    }
    segments.join("::")
}
/// Joins a module prefix and a name with `::`. An empty prefix yields `name` unchanged.
pub fn qualify(prefix: &str, name: &str) -> String {
    match prefix {
        "" => name.to_owned(),
        _ => [prefix, name].join("::"),
    }
}
#[cfg(test)]
mod tests {
use std::fs;
use tempfile::TempDir;
use super::*;
fn create_test_project(dir: &Path) {
let src = dir.join("src");
fs::create_dir_all(src.join("walker")).unwrap();
fs::write(src.join("main.rs"), "fn main() { walk(); }\nfn walk() {}\n").unwrap();
fs::write(
src.join("resolver.rs"),
"\
struct Resolver;
impl Resolver {
pub fn resolve(&self) -> bool { true }
fn internal_resolve(&self) {}
}
fn helper() {}
",
)
.unwrap();
fs::write(
src.join("walker").join("mod.rs"),
"pub fn walk_dir() {}\nfn scan() {}\n",
)
.unwrap();
fs::write(
src.join("special_fns.rs"),
"\
const fn fixed_size() -> usize { 42 }
unsafe fn dangerous() -> i32 { 0 }
extern \"C\" fn ffi_callback() {}
fn normal_fn() {}
struct Widget;
impl Widget {
const fn none() -> Option<Self> { None }
unsafe fn raw_ptr(&self) -> *const u8 { std::ptr::null() }
fn valid_method(&self) {}
}
trait Processor {
fn process(&self);
fn default_method(&self) { }
unsafe fn unsafe_default(&self) { }
fn required_method(&self);
}
",
)
.unwrap();
fs::write(
src.join("with_tests.rs"),
"\
fn production_fn() {}
#[test]
fn test_something() {}
#[cfg(test)]
mod tests {
fn test_helper() {}
#[test]
fn it_works() {}
}
",
)
.unwrap();
}
#[test]
fn resolve_fn_by_substring() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("walk".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(all_fns.contains(&"walk"), "should match exact 'walk'");
assert!(
all_fns.contains(&"walk_dir"),
"should match 'walk_dir' (substring)"
);
assert!(!all_fns.contains(&"helper"), "should not match 'helper'");
assert!(!all_fns.contains(&"scan"), "should not match 'scan'");
}
#[test]
fn resolve_fn_finds_impl_methods() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("resolve".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"Resolver::resolve"),
"should match impl method 'resolve'"
);
assert!(
all_fns.contains(&"Resolver::internal_resolve"),
"should match impl method 'internal_resolve'"
);
}
#[test]
fn resolve_file_gets_all_functions() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::File("resolver.rs".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
assert_eq!(result.targets.len(), 1);
let fns: Vec<&str> = result.targets[0]
.functions
.iter()
.map(|qf| qf.minimal.as_str())
.collect();
assert!(fns.contains(&"helper"));
assert!(fns.contains(&"Resolver::internal_resolve"));
assert!(fns.contains(&"Resolver::resolve"));
}
#[test]
fn resolve_mod_gets_directory_module() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Mod("walker".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
assert_eq!(result.targets.len(), 1);
let fns: Vec<&str> = result.targets[0]
.functions
.iter()
.map(|qf| qf.minimal.as_str())
.collect();
assert!(fns.contains(&"walk_dir"));
assert!(fns.contains(&"scan"));
}
#[test]
fn no_match_returns_error() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("nonexistent_xyz".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false);
assert!(result.is_err(), "should error when no functions match");
let err = result.unwrap_err().to_string();
assert!(
err.contains("nonexistent_xyz"),
"error should mention the pattern: {err}"
);
assert!(
err.contains("Found 14 functions"),
"error should show function count: {err}"
);
assert!(
err.contains("Run without --fn"),
"error should suggest running without --fn: {err}"
);
}
#[test]
fn resolve_skips_test_functions_and_cfg_test_modules() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::File("with_tests.rs".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"production_fn"),
"should include production function"
);
assert!(
!all_fns.contains(&"test_something"),
"should skip #[test] function"
);
assert!(
!all_fns.contains(&"test_helper"),
"should skip function inside #[cfg(test)] module"
);
assert!(
!all_fns.contains(&"it_works"),
"should skip #[test] inside #[cfg(test)] module"
);
}
#[test]
fn resolve_skips_unparseable_files() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let src = tmp.path().join("src");
fs::write(
src.join("template.tera.rs"),
"{% for variant in variants %}\nfn {{ variant }}() {}\n{% endfor %}\n",
)
.unwrap();
let specs = [TargetSpec::Fn("walk".into())];
let result = resolve_targets(&src, &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"walk"),
"should still find valid functions"
);
assert!(
all_fns.contains(&"walk_dir"),
"should still find valid functions"
);
}
#[test]
fn resolve_empty_specs_returns_all_functions() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs: Vec<TargetSpec> = vec![];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(all_fns.contains(&"main"), "should include main");
assert!(
all_fns.contains(&"walk"),
"should include walk from main.rs"
);
assert!(
all_fns.contains(&"helper"),
"should include helper from resolver.rs"
);
assert!(
all_fns.contains(&"Resolver::resolve"),
"should include impl methods"
);
assert!(
all_fns.contains(&"walk_dir"),
"should include walk_dir from walker"
);
assert!(all_fns.contains(&"scan"), "should include scan from walker");
assert!(
all_fns.contains(&"production_fn"),
"should include production_fn"
);
assert!(!all_fns.contains(&"test_something"), "should skip #[test]");
assert!(
!all_fns.contains(&"it_works"),
"should skip test in cfg(test)"
);
}
#[test]
fn resolve_skips_const_and_extern_functions() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::File("special_fns.rs".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(all_fns.contains(&"normal_fn"), "should include normal_fn");
assert!(
all_fns.contains(&"Widget::valid_method"),
"should include Widget::valid_method"
);
assert!(all_fns.contains(&"dangerous"), "should include unsafe fn");
assert!(
all_fns.contains(&"Widget::raw_ptr"),
"should include unsafe impl method"
);
assert!(!all_fns.contains(&"fixed_size"), "should skip const fn");
assert!(!all_fns.contains(&"ffi_callback"), "should skip extern fn");
assert!(
!all_fns.contains(&"Widget::none"),
"should skip const impl method"
);
}
#[test]
fn resolve_instruments_unsafe_trait_default_methods() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::File("special_fns.rs".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"Processor::default_method"),
"should include safe trait default method"
);
assert!(
all_fns.contains(&"Processor::unsafe_default"),
"should include unsafe trait default method"
);
assert!(
!all_fns.contains(&"Processor::process"),
"should not include trait method without default body"
);
assert!(
!all_fns.contains(&"Processor::required_method"),
"should not include trait method without default body"
);
let skipped_names: Vec<&str> = result.skipped.iter().map(|s| s.name.as_str()).collect();
assert!(
!skipped_names.contains(&"Processor::unsafe_default"),
"unsafe trait default method should not be skipped"
);
}
#[test]
fn no_match_error_includes_suggestions_for_typo() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("heper".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("helper"),
"error should suggest 'helper' for typo 'heper': {err}"
);
assert!(
err.contains("Did you mean"),
"error should include 'Did you mean' phrasing: {err}"
);
}
#[test]
fn levenshtein_basic_cases() {
assert_eq!(levenshtein("", ""), 0);
assert_eq!(levenshtein("abc", "abc"), 0);
assert_eq!(levenshtein("abc", ""), 3);
assert_eq!(levenshtein("", "abc"), 3);
assert_eq!(levenshtein("kitten", "sitting"), 3);
assert_eq!(levenshtein("parse", "prase"), 2); assert_eq!(levenshtein("walk", "wlak"), 2);
assert_eq!(levenshtein("a", "b"), 1);
}
#[test]
fn no_match_error_shows_count_for_large_project() {
let tmp = TempDir::new().unwrap();
let src = tmp.path().join("src");
fs::create_dir_all(&src).unwrap();
let mut code = String::new();
for i in 0..20 {
code.push_str(&format!("fn func_{i}() {{}}\n"));
}
fs::write(src.join("many.rs"), &code).unwrap();
let specs = [TargetSpec::Fn("nonexistent".into())];
let result = resolve_targets(&src, &specs, false);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Found 20 functions"),
"error should show function count: {err}"
);
assert!(
err.contains("Run without --fn"),
"error should suggest running without --fn: {err}"
);
}
#[test]
fn no_match_error_shows_clean_patterns() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("zzz_nonexistent".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false);
let err = result.unwrap_err().to_string();
assert!(
err.starts_with("no functions matched"),
"error should start with 'no functions matched': {err}"
);
assert!(
err.contains("--fn zzz_nonexistent"),
"error should include the spec: {err}"
);
}
#[test]
fn resolve_fn_substring_matches_qualified_name() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("Resolver".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"Resolver::resolve"),
"should match 'Resolver::resolve' via qualified name substring"
);
assert!(
all_fns.contains(&"Resolver::internal_resolve"),
"should match 'Resolver::internal_resolve' via qualified name substring"
);
assert!(
!all_fns.contains(&"helper"),
"should not match unrelated 'helper'"
);
}
#[test]
fn resolve_fn_exact_match() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("walk".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, true).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(all_fns.contains(&"walk"), "should match exact 'walk'");
assert!(
!all_fns.contains(&"walk_dir"),
"should NOT match 'walk_dir' in exact mode"
);
}
#[test]
fn resolve_fn_exact_match_qualified() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("Resolver::resolve".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, true).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"Resolver::resolve"),
"should match qualified 'Resolver::resolve'"
);
assert!(
!all_fns.contains(&"Resolver::internal_resolve"),
"should NOT match 'Resolver::internal_resolve' in exact mode"
);
}
#[test]
fn resolve_fn_exact_no_match_shows_error() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("wal".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, true);
assert!(result.is_err(), "partial match should fail in exact mode");
let err = result.unwrap_err().to_string();
assert!(
err.contains("no functions matched"),
"error should say no functions matched: {err}"
);
}
#[test]
fn resolve_finds_unsafe_fn_by_pattern() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::Fn("dangerous".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(
all_fns.contains(&"dangerous"),
"should find and instrument unsafe fn 'dangerous': {all_fns:?}"
);
}
#[test]
fn resolve_reports_skipped_functions_with_reasons() {
let tmp = TempDir::new().unwrap();
create_test_project(tmp.path());
let specs = [TargetSpec::File("special_fns.rs".into())];
let result = resolve_targets(&tmp.path().join("src"), &specs, false).unwrap();
let skipped_names: Vec<(&str, &SkipReason)> = result
.skipped
.iter()
.map(|s| (s.name.as_str(), &s.reason))
.collect();
assert!(
skipped_names.contains(&("fixed_size", &SkipReason::Const)),
"should report const fn as skipped: {skipped_names:?}"
);
assert!(
skipped_names.contains(&("ffi_callback", &SkipReason::ExternAbi)),
"should report extern fn as skipped: {skipped_names:?}"
);
assert!(
skipped_names.contains(&("Widget::none", &SkipReason::Const)),
"should report const impl method as skipped: {skipped_names:?}"
);
assert!(
!skipped_names.iter().any(|(name, _)| *name == "dangerous"),
"unsafe fn should not be skipped: {skipped_names:?}"
);
assert!(
!skipped_names
.iter()
.any(|(name, _)| *name == "Widget::raw_ptr"),
"unsafe impl method should not be skipped: {skipped_names:?}"
);
for s in &result.skipped {
assert_eq!(
s.path,
Path::new("src/special_fns.rs"),
"skipped fn '{}' should have relative path src/special_fns.rs",
s.name
);
}
let all_fns: Vec<&str> = result
.targets
.iter()
.flat_map(|r| r.functions.iter().map(|qf| qf.minimal.as_str()))
.collect();
assert!(all_fns.contains(&"normal_fn"));
assert!(all_fns.contains(&"Widget::valid_method"));
assert!(all_fns.contains(&"dangerous"));
assert!(all_fns.contains(&"Widget::raw_ptr"));
}
#[test]
fn hint_empty_when_only_file_and_mod_specs() {
let specs = [
TargetSpec::File("lib.rs".into()),
TargetSpec::Mod("mymod".into()),
];
let hint = build_suggestion_hint(&specs, &[]);
assert!(
hint.is_empty(),
"hint should be empty when no --fn specs are present, got: {hint:?}"
);
}
#[test]
fn module_prefix_crate_roots() {
assert_eq!(module_prefix(Path::new("main.rs")), "");
assert_eq!(module_prefix(Path::new("lib.rs")), "");
}
#[test]
fn module_prefix_simple_files() {
assert_eq!(module_prefix(Path::new("db.rs")), "db");
assert_eq!(module_prefix(Path::new("utils.rs")), "utils");
}
#[test]
fn module_prefix_mod_rs() {
assert_eq!(module_prefix(Path::new("db/mod.rs")), "db");
assert_eq!(
module_prefix(Path::new("api/handlers/mod.rs")),
"api::handlers"
);
}
#[test]
fn module_prefix_nested_files() {
assert_eq!(module_prefix(Path::new("db/query.rs")), "db::query");
assert_eq!(
module_prefix(Path::new("api/handlers/user.rs")),
"api::handlers::user"
);
}
#[cfg(unix)]
#[test]
fn module_prefix_non_utf8_fallback() {
use std::ffi::OsStr;
use std::os::unix::ffi::OsStrExt;
let bad_dir = OsStr::from_bytes(b"\x80");
let mut path = std::path::PathBuf::from(bad_dir);
path.push("query.rs");
assert_eq!(module_prefix(&path), "_::query");
let bad_stem = OsStr::from_bytes(b"\x80.rs");
let path2 = Path::new("api").join(bad_stem);
assert_eq!(module_prefix(&path2), "api::_");
}
#[test]
fn qualify_with_empty_prefix() {
assert_eq!(qualify("", "walk"), "walk");
assert_eq!(qualify("", "Walker::walk"), "Walker::walk");
}
#[test]
fn qualify_with_module_prefix() {
assert_eq!(qualify("db", "execute"), "db::execute");
assert_eq!(
qualify("db::query", "validate_input"),
"db::query::validate_input"
);
assert_eq!(
qualify("api", "Handler::validate_input"),
"api::Handler::validate_input"
);
}
/// Two trait impls on the same type both define `fmt`; resolution must keep
/// both, named with the `<Type as Trait>::method` form so they stay distinct.
#[test]
fn trait_impl_methods_get_disambiguated_names() {
    let tmp = TempDir::new().unwrap();
    let src_dir = tmp.path().join("src");
    fs::create_dir_all(&src_dir).unwrap();
    let main_rs = r#"
use std::fmt;
struct Point;
impl fmt::Display for Point {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "point")
}
}
impl fmt::Debug for Point {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "point")
}
}
fn main() {}
"#;
    fs::write(src_dir.join("main.rs"), main_rs).unwrap();

    let resolved = resolve_targets(&src_dir, &[], false).unwrap();
    let names: Vec<&str> = resolved
        .targets
        .iter()
        .flat_map(|target| &target.functions)
        .map(|qf| qf.minimal.as_str())
        .collect();

    for (expected, trait_name) in [
        ("<Point as Display>::fmt", "Display"),
        ("<Point as Debug>::fmt", "Debug"),
    ] {
        assert!(
            names.contains(&expected),
            "should have {trait_name}-qualified name: {names:?}"
        );
    }

    let fmt_count = names.iter().filter(|name| name.ends_with("fmt")).count();
    assert_eq!(
        fmt_count, 2,
        "should have 2 distinct fmt entries: {names:?}"
    );
}
/// Checks `classify` across the const / extern-ABI combinations.
///
/// Each row is `(const flag, ABI, expected)`. `None` is taken here to mean
/// "no `extern` qualifier", and `Some("")` presumably models a bare `extern`
/// with no ABI string (TODO confirm against `classify_cst_fn`). An explicit
/// `"Rust"` ABI stays instrumentable; any other ABI is skipped.
///
/// `assert_eq!` reports the actual classification on failure, and the
/// message names the inputs that produced it.
#[test]
fn classify_returns_explicit_classification() {
    let cases = [
        (false, None, Classification::Instrumentable),
        (true, None, Classification::Skip(SkipReason::Const)),
        (false, Some("C"), Classification::Skip(SkipReason::ExternAbi)),
        (false, Some(""), Classification::Skip(SkipReason::ExternAbi)),
        (false, Some("Rust"), Classification::Instrumentable),
    ];
    for (is_const, abi, expected) in cases {
        assert_eq!(
            classify(is_const, abi),
            expected,
            "classify({is_const}, {abi:?})"
        );
    }
}
/// Parses real `fn` signatures and checks that `classify_cst_fn` reaches the
/// same verdicts as `classify`.
///
/// Of the qualifiers exercised here, only `const` and a non-`"Rust"` `extern`
/// ABI (including a bare `extern`) cause a skip. `unsafe`, `async`, generics
/// and `impl Trait` return types do not.
#[test]
fn classify_cst_fn_delegates_to_classify() {
    /// Wraps `sig` with an empty body, parses it, and returns the first `fn`.
    fn parse_fn(sig: &str) -> ast::Fn {
        let file = ast::SourceFile::parse(&format!("{sig} {{}}"), Edition::Edition2024).tree();
        file.syntax()
            .descendants()
            .find_map(ast::Fn::cast)
            .expect("should parse fn")
    }

    let cases = [
        ("fn foo()", Classification::Instrumentable),
        ("const fn foo()", Classification::Skip(SkipReason::Const)),
        ("unsafe fn foo()", Classification::Instrumentable),
        ("extern \"C\" fn foo()", Classification::Skip(SkipReason::ExternAbi)),
        ("extern fn foo()", Classification::Skip(SkipReason::ExternAbi)),
        ("extern \"Rust\" fn foo()", Classification::Instrumentable),
        ("async fn foo()", Classification::Instrumentable),
        ("fn foo<T: Clone>(x: T) -> T", Classification::Instrumentable),
        (
            "fn foo() -> impl std::future::Future<Output = i32>",
            Classification::Instrumentable,
        ),
    ];
    for (sig, expected) in cases {
        let func = parse_fn(sig);
        assert_eq!(classify_cst_fn(&func), expected, "signature: {sig}");
    }
}
/// A `--fn walk` selector must narrow `targets` to walk-matching functions
/// while `all_functions` still lists the unselected ones (needed for
/// pass-through).
#[test]
fn all_functions_includes_unselected() {
    let dir = TempDir::new().unwrap();
    create_test_project(dir.path());
    let src = dir.path().join("src");
    // An array is enough here; `&specs` coerces to the `&[TargetSpec]` parameter.
    let specs = [TargetSpec::Fn("walk".to_string())];
    let result = resolve_targets(&src, &specs, false).unwrap();
    let measured_names: Vec<&str> = result
        .targets
        .iter()
        .flat_map(|t| t.functions.iter().map(|qf| qf.minimal.as_str()))
        .collect();
    // `all` is vacuously true on an empty list, so require at least one
    // selected function; otherwise a resolver that selects nothing would pass.
    assert!(
        !measured_names.is_empty(),
        "the walk selector should match at least one function"
    );
    assert!(
        measured_names.iter().all(|n| n.contains("walk")),
        "targets should only contain walk-matching functions. Got: {measured_names:?}"
    );
    let all_names: Vec<&str> = result
        .all_functions
        .iter()
        .flat_map(|t| t.functions.iter().map(|qf| qf.minimal.as_str()))
        .collect();
    assert!(
        all_names.len() >= measured_names.len(),
        "all_functions should be >= targets. all={}, targets={}",
        all_names.len(),
        measured_names.len()
    );
    let has_non_walk = all_names.iter().any(|n| !n.contains("walk"));
    assert!(
        has_non_walk,
        "all_functions should include non-walk functions for pass-through. Got: {all_names:?}"
    );
}
/// With no selectors every discovered function is a target, so the set of
/// target names must equal the set of `all_functions` names.
#[test]
fn all_functions_equals_targets_when_no_specs() {
    use std::collections::BTreeSet;

    let dir = TempDir::new().unwrap();
    create_test_project(dir.path());
    let src = dir.path().join("src");
    let result = resolve_targets(&src, &[], false).unwrap();

    // Collapse both groups to sorted sets of minimal names so that file
    // grouping and ordering do not matter.
    let [target_names, all_names]: [BTreeSet<String>; 2] =
        [&result.targets, &result.all_functions].map(|group| {
            group
                .iter()
                .flat_map(|t| t.functions.iter())
                .map(|qf| qf.minimal.clone())
                .collect()
        });

    assert_eq!(
        target_names, all_names,
        "when no selectors, all_functions should equal targets"
    );
}
/// A method generated by a macro call inside an `impl` block must be
/// discovered and impl-qualified (`S::get_value`), just like a hand-written
/// method in the same block (`S::regular_method`).
///
/// Every discovered name, at all three qualification levels, is included in
/// the failure messages. This replaces unconditional `eprintln!` output, so
/// passing runs stay quiet.
#[test]
fn macro_in_impl_naming_diagnostic() {
    let source = r#"
macro_rules! add_method {
($name:ident) => {
fn $name(&self) -> u32 { 42 }
};
}
struct S;
impl S {
add_method!(get_value);
fn regular_method(&self) -> u32 { 99 }
}
fn main() {}
"#;
    let (functions, _) = extract_functions(source, PathBuf::from("test.rs"));
    let discovered: Vec<String> = functions
        .iter()
        .map(|f| {
            format!(
                "minimal={:?} medium={:?} full={:?}",
                f.minimal, f.medium, f.full
            )
        })
        .collect();

    let Some(macro_fn) = functions.iter().find(|f| f.minimal.contains("get_value")) else {
        panic!("macro-generated method should be discovered; found: {discovered:#?}");
    };
    let Some(regular_fn) = functions
        .iter()
        .find(|f| f.minimal.contains("regular_method"))
    else {
        panic!("regular method should be discovered; found: {discovered:#?}");
    };

    assert_eq!(
        macro_fn.minimal, "S::get_value",
        "macro-generated method should be impl-qualified; found: {discovered:#?}"
    );
    assert_eq!(
        regular_fn.minimal, "S::regular_method",
        "regular method should be impl-qualified; found: {discovered:#?}"
    );
}
}