use syn::visit::Visit;
/// A single call site found by `detect_function_calls`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct DetectedCall {
    /// Name of the callee: the last path segment for path calls (`foo` in
    /// `a::b::foo(..)`), or the method identifier for `recv.method(..)` calls.
    pub function_name: String,
    /// Syntactic shape of the call.
    pub call_type: CallType,
}

/// How a call was written in the source. The classification is purely syntactic.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CallType {
    /// A call through a single-segment path, e.g. `foo(x)`.
    FreeFunction,
    /// A call through a multi-segment path, e.g. `Type::func(x)`; `type_name` is the
    /// second-to-last path segment. Module-qualified free functions such as
    /// `utils::helper(x)` are also classified this way.
    StaticMethod { type_name: String },
    /// A method call `recv.method(x)`; `receiver` is a textual rendering of `recv` when it
    /// is a path, field access or method-call chain, and `None` for any other expression.
    MethodCall { receiver: Option<String> },
}
/// Callee names that `detect_function_calls` never reports.
///
/// The check compares against `DetectedCall::function_name` (the last path segment, or
/// the method name) for every call kind. A user method or static function named, e.g.,
/// `log`, `write` or `file` is therefore dropped as well.
///
/// Macro invocations are never recorded as calls themselves, so the macro-name entries
/// appear to be purely defensive.
///
/// NOTE(review): the type-name entries (`Box`, `Vec`, `String`, ...) only match a call
/// whose *final* segment has that name. `Box::new(..)` is reported as `new` and is not
/// filtered by them. Confirm whether filtering on the owning type name was intended.
const FILTERED_NAMES: &[&str] = &[
    // Prelude enum-variant constructors.
    "Ok",
    "Some",
    "Err",
    "None",
    // Collection / formatting / printing macros.
    "vec",
    "format",
    "println",
    "eprintln",
    "print",
    "eprint",
    // Panicking placeholders.
    "panic",
    "todo",
    "unimplemented",
    "unreachable",
    // Assertions.
    "assert",
    "assert_eq",
    "assert_ne",
    "debug_assert",
    "debug_assert_eq",
    "debug_assert_ne",
    // Writers and logging.
    "write",
    "writeln",
    "log",
    // Compile-time / built-in macros.
    "cfg",
    "include",
    "include_str",
    "include_bytes",
    "env",
    "option_env",
    "concat",
    "stringify",
    "file",
    "line",
    "column",
    "module_path",
    // Standard smart-pointer and container type names.
    "Box",
    "Vec",
    "String",
    "Arc",
    "Rc",
    "Mutex",
    "RefCell",
];
/// Method names skipped when recording `recv.method(..)` calls.
///
/// This list is consulted only for method-call syntax. A path call such as
/// `String::from(x)` or `Foo::default()` is still reported. Skipped method calls are still
/// descended into, so calls in their receiver and arguments are found.
const FILTERED_METHOD_NAMES: &[&str] = &[
    // Unwrapping and cloning.
    "unwrap",
    "expect",
    "clone",
    "to_string",
    "to_owned",
    // Iteration and consuming adaptors.
    "iter",
    "into_iter",
    "map",
    "filter",
    "collect",
    "fold",
    "for_each",
    "find",
    "any",
    "all",
    // Collection access and mutation.
    "push",
    "pop",
    "len",
    "is_empty",
    "contains",
    "get",
    "insert",
    "remove",
    "extend",
    // Option / Result combinators.
    "ok_or",
    "ok_or_else",
    "map_err",
    "and_then",
    "or_else",
    "unwrap_or",
    "unwrap_or_else",
    // Borrowing and conversions.
    "as_ref",
    "as_mut",
    "borrow",
    "borrow_mut",
    "into",
    "from",
    "try_into",
    "try_from",
    "default",
    "to_vec",
    "as_slice",
    "as_str",
    // Error-context helpers.
    "change_context",
    "attach_printable",
    // Option / Result predicates.
    "is_some",
    "is_none",
    "is_ok",
    "is_err",
    // String helpers.
    "trim",
    "trim_start",
    "trim_end",
    "split",
    "join",
    "replace",
    "starts_with",
    "ends_with",
    "lines",
    "chars",
    "bytes",
    // Further iterator adaptors.
    "next",
    "enumerate",
    "skip",
    "take",
    "zip",
    "chain",
    "flat_map",
    "flatten",
    "filter_map",
    "position",
    "count",
    // Ordering and de-duplication.
    "sort",
    "sort_by",
    "sort_by_key",
    "dedup",
];
pub fn detect_function_calls(function_source: &str) -> Result<Vec<DetectedCall>, String> {
let item_fn = syn::parse_str::<syn::ItemFn>(function_source).or_else(|_| {
let wrapped = format!("fn __wrapper() {{ {} }}", function_source);
syn::parse_str::<syn::ItemFn>(&wrapped)
});
let item_fn = match item_fn {
Ok(f) => f,
Err(e) => return Err(format!("syn parse error: {}", e)),
};
let mut visitor = CallVisitor { calls: Vec::new() };
visitor.visit_item_fn(&item_fn);
let mut seen = std::collections::HashSet::new();
visitor.calls.retain(|call| {
if FILTERED_NAMES.contains(&call.function_name.as_str()) {
return false;
}
seen.insert(call.function_name.clone())
});
Ok(visitor.calls)
}
/// Renders a method-call receiver as text, or returns `None` when its shape is unsupported.
///
/// Supported shapes, applied recursively:
/// * paths: the segment identifiers joined with `::` (generic arguments are dropped),
///   e.g. `self`, `Foo::BAR`;
/// * field accesses: `base.field` / `base.0`, or just the member when `base` is unsupported;
/// * method calls: `base.method()` (arguments omitted), or `method()` when `base` is
///   unsupported.
///
/// Every other expression kind yields `None`.
fn receiver_to_string(expr: &syn::Expr) -> Option<String> {
    // Prefixes `tail` with `prefix.` when the base expression could be rendered.
    fn dotted(base: Option<String>, tail: String) -> String {
        match base {
            Some(prefix) => format!("{}.{}", prefix, tail),
            None => tail,
        }
    }

    match expr {
        syn::Expr::Path(p) => {
            let idents: Vec<String> = p
                .path
                .segments
                .iter()
                .map(|segment| segment.ident.to_string())
                .collect();
            Some(idents.join("::"))
        }
        syn::Expr::Field(f) => {
            let member = match &f.member {
                syn::Member::Named(name) => name.to_string(),
                syn::Member::Unnamed(position) => position.index.to_string(),
            };
            Some(dotted(receiver_to_string(&f.base), member))
        }
        syn::Expr::MethodCall(call) => {
            let invoked = format!("{}()", call.method);
            Some(dotted(receiver_to_string(&call.receiver), invoked))
        }
        _ => None,
    }
}
/// `syn` visitor that accumulates the calls it encounters, in visit order.
///
/// Name filtering against `FILTERED_NAMES` and de-duplication are left to the caller
/// (`detect_function_calls`).
struct CallVisitor {
    /// Calls collected so far; may contain duplicates and names from `FILTERED_NAMES`.
    calls: Vec<DetectedCall>,
}
impl<'ast> Visit<'ast> for CallVisitor {
    /// Records calls whose callee is a plain path.
    ///
    /// * `f(..)` becomes `CallType::FreeFunction` named `f`.
    /// * `a::B::f(..)` becomes `CallType::StaticMethod` named `f`, with `type_name` set to the
    ///   second-to-last segment (`B`). This is purely syntactic, so a module-qualified free
    ///   function such as `utils::helper(..)` is also classified as a static method.
    ///
    /// Calls through any other callee expression (e.g. `(self.cb)(x)`) are not recorded,
    /// but their sub-expressions are still visited.
    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
        if let syn::Expr::Path(expr_path) = &*node.func {
            let segments = &expr_path.path.segments;
            let len = segments.len();
            if len == 1 {
                let name = segments[0].ident.to_string();
                self.calls.push(DetectedCall {
                    function_name: name,
                    call_type: CallType::FreeFunction,
                });
            } else if len >= 2 {
                let type_name = segments[len - 2].ident.to_string();
                let func_name = segments[len - 1].ident.to_string();
                self.calls.push(DetectedCall {
                    function_name: func_name,
                    call_type: CallType::StaticMethod { type_name },
                });
            }
        }
        // Keep descending so nested calls such as `inner` in `outer(inner(x))` are found.
        syn::visit::visit_expr_call(self, node);
    }

    /// Records `recv.method(..)` unless `method` is listed in `FILTERED_METHOD_NAMES`.
    /// The receiver text comes from `receiver_to_string`.
    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
        let name = node.method.to_string();
        if !FILTERED_METHOD_NAMES.contains(&name.as_str()) {
            let receiver = receiver_to_string(&node.receiver);
            self.calls.push(DetectedCall {
                function_name: name,
                call_type: CallType::MethodCall { receiver },
            });
        }
        // Descend even into filtered methods: the receiver chain and arguments may contain
        // interesting calls (e.g. `transform` in `items.iter().map(|i| transform(i))`).
        syn::visit::visit_expr_method_call(self, node);
    }

    /// Looks inside macro invocations such as `assert_eq!`, `println!`, `format!` and `vec!`.
    ///
    /// `syn` stores macro arguments as an opaque token stream, so the default traversal never
    /// sees a call like `compute(x)` in `assert_eq!(compute(x), 1)`. Here the body is parsed
    /// as a comma-separated list of expressions and each one is visited. Some bodies do not
    /// parse that way, e.g. custom DSLs, `vec![0; n]` or `macro_rules!` definitions; those
    /// are skipped, exactly as before. Pattern arguments that also parse as expressions
    /// (e.g. `Foo::Bar(_)` in `matches!`) are treated like the equivalent expression.
    /// The macro's own name is never recorded as a call.
    fn visit_macro(&mut self, node: &'ast syn::Macro) {
        syn::visit::visit_macro(self, node);
        let parser = syn::punctuated::Punctuated::<syn::Expr, syn::token::Comma>::parse_terminated;
        if let Ok(args) = node.parse_body_with(parser) {
            for arg in &args {
                // `args` is a freshly parsed tree that lives shorter than `'ast`. The free
                // function instantiates this visitor (generic over every `'ast`) at that
                // shorter lifetime while still dispatching to the overrides above.
                syn::visit::visit_expr(self, arg);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the detector, failing the test on a parse error.
    fn calls_in(source: &str) -> Vec<DetectedCall> {
        detect_function_calls(source).unwrap()
    }

    /// Whether any detected call is named `name`.
    fn has_name(calls: &[DetectedCall], name: &str) -> bool {
        calls.iter().any(|c| c.function_name == name)
    }

    /// Whether a call named `name` was detected with exactly `expected` as its call type.
    fn has_typed(calls: &[DetectedCall], name: &str, expected: &CallType) -> bool {
        calls
            .iter()
            .any(|c| c.function_name == name && c.call_type == *expected)
    }

    /// Whether a method call `name` was detected on the receiver text `receiver`.
    fn has_method_on(calls: &[DetectedCall], name: &str, receiver: &str) -> bool {
        calls.iter().any(|c| {
            c.function_name == name
                && matches!(&c.call_type, CallType::MethodCall { receiver: Some(r) } if r == receiver)
        })
    }

    /// Shorthand for a `CallType::StaticMethod` on `type_name`.
    fn static_on(type_name: &str) -> CallType {
        CallType::StaticMethod {
            type_name: type_name.to_string(),
        }
    }

    #[test]
    fn test_detect_free_function() {
        let calls = calls_in(r#"fn test() { foo(1, 2); }"#);
        assert!(has_typed(&calls, "foo", &CallType::FreeFunction));
    }

    #[test]
    fn test_detect_static_method() {
        let calls = calls_in(r#"fn test() { MyStruct::do_something(x); }"#);
        assert!(has_typed(&calls, "do_something", &static_on("MyStruct")));
    }

    #[test]
    fn test_detect_method_call() {
        let calls = calls_in(r#"fn test() { obj.method(arg); }"#);
        assert!(has_method_on(&calls, "method", "obj"));
    }

    #[test]
    fn test_detect_method_call_with_field_receiver() {
        let calls = calls_in(r#"fn test() { ctx.accounts.process(); }"#);
        assert!(has_method_on(&calls, "process", "ctx.accounts"));
    }

    #[test]
    fn test_filters_common_names() {
        let calls = calls_in(r#"fn test() { Ok(value); Some(x); vec![1,2]; foo(1); }"#);
        assert!(!has_name(&calls, "Ok"));
        assert!(!has_name(&calls, "Some"));
        assert!(has_name(&calls, "foo"));
    }

    #[test]
    fn test_deduplicates() {
        let calls = calls_in(r#"fn test() { foo(1); foo(2); foo(3); }"#);
        let foo_count = calls.iter().filter(|c| c.function_name == "foo").count();
        assert_eq!(foo_count, 1);
    }

    #[test]
    fn test_nested_calls() {
        let calls = calls_in(r#"fn test() { outer(inner(x)); }"#);
        for name in ["outer", "inner"] {
            assert!(has_name(&calls, name));
        }
    }

    #[test]
    fn test_chained_methods() {
        let calls = calls_in(r#"fn test() { x.foo().bar().baz(); }"#);
        for name in ["foo", "bar", "baz"] {
            assert!(has_name(&calls, name));
        }
    }

    #[test]
    fn test_closure_calls() {
        let calls = calls_in(
            r#"fn test() { let f = |x| compute(x); items.iter().map(|i| transform(i)); }"#,
        );
        for name in ["compute", "transform"] {
            assert!(has_name(&calls, name));
        }
    }

    #[test]
    fn test_body_only_fallback() {
        let calls = calls_in(r#"let x = foo(1); bar(x);"#);
        for name in ["foo", "bar"] {
            assert!(has_name(&calls, name));
        }
    }

    #[test]
    fn test_long_path_static() {
        let calls = calls_in(r#"fn test() { module::SubModule::create(arg); }"#);
        assert!(has_typed(&calls, "create", &static_on("SubModule")));
    }
}