bat-cli 0.8.10

Blockchain Auditor Toolkit (BAT)
use syn::visit::Visit;

/// A call discovered while walking a function's syntax tree.
#[derive(Clone, Debug, PartialEq)]
pub struct DetectedCall {
    /// Name of the called function or method; for path-qualified calls this is
    /// the last path segment.
    pub function_name: String,
    /// Syntactic form in which the call was written.
    pub call_type: CallType,
}

/// Syntactic form of a detected call.
#[derive(Clone, Debug, PartialEq)]
pub enum CallType {
    /// Call through a single-segment path, e.g. `foo(x)`.
    FreeFunction,
    /// Call through a path with two or more segments, e.g. `MyStruct::create(x)`.
    /// `type_name` is the segment immediately before the function name, so it
    /// may name a module rather than a type.
    StaticMethod { type_name: String },
    /// Method-call syntax, e.g. `obj.method(x)`. `receiver` is a textual
    /// rendering of the receiver expression, or `None` when the receiver is a
    /// kind of expression that is not rendered.
    MethodCall { receiver: Option<String> },
}

/// Names that `detect_function_calls` removes from its output: `Option`/`Result`
/// constructors, standard macro names and common standard-library type names.
const FILTERED_NAMES: &[&str] = &[
    "Ok", "Some", "Err", "None", "vec", "format", "println", "eprintln", "print", "eprint",
    "panic", "todo", "unimplemented", "unreachable", "assert", "assert_eq", "assert_ne",
    "debug_assert", "debug_assert_eq", "debug_assert_ne", "write", "writeln", "log",
    "cfg", "include", "include_str", "include_bytes", "env", "option_env",
    "concat", "stringify", "file", "line", "column", "module_path",
    "Box", "Vec", "String", "Arc", "Rc", "Mutex", "RefCell",
];

/// Method names that are not recorded when they appear in method-call syntax
/// (`x.name(..)`). The receiver and arguments of such a call are still
/// traversed, so calls nested inside them are detected.
const FILTERED_METHOD_NAMES: &[&str] = &[
    // Option/Result unwrapping and cloning/ownership helpers
    "unwrap", "expect", "clone", "to_string", "to_owned",
    // Iterator adaptors and consumers
    "iter", "into_iter", "map", "filter", "collect", "fold", "for_each", "find", "any", "all",
    // Collection accessors and mutators
    "push", "pop", "len", "is_empty", "contains", "get", "insert", "remove", "extend",
    // Option/Result combinators
    "ok_or", "ok_or_else", "map_err", "and_then", "or_else", "unwrap_or", "unwrap_or_else",
    // Borrowing helpers
    "as_ref", "as_mut", "borrow", "borrow_mut",
    // Conversions
    "into", "from", "try_into", "try_from",
    "default", "to_vec", "as_slice", "as_str",
    // Error-context helpers (presumably from an error-reporting crate; verify)
    "change_context", "attach_printable",
    // Option/Result predicates
    "is_some", "is_none", "is_ok", "is_err",
    // String helpers
    "trim", "trim_start", "trim_end", "split", "join", "replace", "starts_with", "ends_with",
    "lines", "chars", "bytes",
    // More iterator adaptors
    "next", "enumerate", "skip", "take", "zip", "chain", "flat_map", "flatten",
    "filter_map", "position", "count",
    // Sorting and deduplication
    "sort", "sort_by", "sort_by_key", "dedup",
];

pub fn detect_function_calls(function_source: &str) -> Result<Vec<DetectedCall>, String> {
    let item_fn = syn::parse_str::<syn::ItemFn>(function_source).or_else(|_| {
        let wrapped = format!("fn __wrapper() {{ {} }}", function_source);
        syn::parse_str::<syn::ItemFn>(&wrapped)
    });

    let item_fn = match item_fn {
        Ok(f) => f,
        Err(e) => return Err(format!("syn parse error: {}", e)),
    };

    let mut visitor = CallVisitor {
        calls: Vec::new(),
    };
    visitor.visit_item_fn(&item_fn);

    // Deduplicate by function_name
    let mut seen = std::collections::HashSet::new();
    visitor.calls.retain(|call| {
        if FILTERED_NAMES.contains(&call.function_name.as_str()) {
            return false;
        }
        seen.insert(call.function_name.clone())
    });

    Ok(visitor.calls)
}

/// Renders the receiver expression of a method call as text.
///
/// * Paths are joined with `::` (`self`, `a::b`); generic arguments are not
///   included.
/// * Field accesses are appended with `.` (`ctx.accounts`, `pair.0`).
/// * Method calls are appended with `.` and a trailing `()` (`x.foo()`);
///   their arguments are not rendered.
///
/// Returns `None` for any other kind of expression. When the base of a field
/// access or method call cannot be rendered, only the member (or `method()`)
/// part is returned.
fn receiver_to_string(expr: &syn::Expr) -> Option<String> {
    let rendered = match expr {
        syn::Expr::Path(path) => {
            let idents: Vec<String> = path
                .path
                .segments
                .iter()
                .map(|segment| segment.ident.to_string())
                .collect();
            idents.join("::")
        }
        syn::Expr::Field(access) => {
            let member = match &access.member {
                syn::Member::Named(ident) => ident.to_string(),
                syn::Member::Unnamed(position) => position.index.to_string(),
            };
            match receiver_to_string(&access.base) {
                Some(prefix) => format!("{}.{}", prefix, member),
                None => member,
            }
        }
        // A chained receiver such as `self.something()` is kept in full,
        // including the `()` marker, rather than reduced to its base.
        syn::Expr::MethodCall(call) => {
            let method = call.method.to_string();
            match receiver_to_string(&call.receiver) {
                Some(prefix) => format!("{}.{}()", prefix, method),
                None => format!("{}()", method),
            }
        }
        _ => return None,
    };
    Some(rendered)
}

/// `syn` visitor that accumulates every call it encounters.
struct CallVisitor {
    /// Calls recorded so far, in visit order; may contain duplicates.
    calls: Vec<DetectedCall>,
}

impl<'ast> Visit<'ast> for CallVisitor {
    /// Records path-based calls such as `foo(x)` or `Type::func(x)`.
    ///
    /// A single-segment callee is recorded as `FreeFunction`. For longer paths
    /// the last segment is the function name and the segment before it becomes
    /// `type_name`. Callees that are not plain paths are not recorded.
    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
        if let syn::Expr::Path(callee) = &*node.func {
            // Walk the path back to front: function name first, then its qualifier.
            let mut from_end = callee.path.segments.iter().rev();
            if let Some(last) = from_end.next() {
                let call_type = match from_end.next() {
                    Some(qualifier) => CallType::StaticMethod {
                        type_name: qualifier.ident.to_string(),
                    },
                    None => CallType::FreeFunction,
                };
                self.calls.push(DetectedCall {
                    function_name: last.ident.to_string(),
                    call_type,
                });
            }
        }
        // Continue with the default traversal so nested calls are found too.
        syn::visit::visit_expr_call(self, node);
    }

    /// Records method-call syntax (`recv.method(x)`) unless the method name is
    /// listed in `FILTERED_METHOD_NAMES`.
    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
        let method = node.method.to_string();
        let is_ignored = FILTERED_METHOD_NAMES.contains(&method.as_str());
        if !is_ignored {
            self.calls.push(DetectedCall {
                function_name: method,
                call_type: CallType::MethodCall {
                    receiver: receiver_to_string(&node.receiver),
                },
            });
        }
        // Ignored methods are still traversed: their receiver and arguments
        // may contain calls that should be recorded.
        syn::visit::visit_expr_method_call(self, node);
    }
}

/// Unit tests for `detect_function_calls`.
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-segment call is reported as a free function.
    #[test]
    fn test_detect_free_function() {
        let source = r#"fn test() { foo(1, 2); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "foo"
            && c.call_type == CallType::FreeFunction));
    }

    /// A two-segment path call is reported as a static method on the first segment.
    #[test]
    fn test_detect_static_method() {
        let source = r#"fn test() { MyStruct::do_something(x); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "do_something"
            && c.call_type == CallType::StaticMethod {
                type_name: "MyStruct".to_string()
            }));
    }

    /// A method call on a plain identifier reports that identifier as receiver.
    #[test]
    fn test_detect_method_call() {
        let source = r#"fn test() { obj.method(arg); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "method"
            && matches!(&c.call_type, CallType::MethodCall { receiver: Some(r) } if r == "obj")));
    }

    /// A field-access receiver is rendered as a dot-separated chain.
    #[test]
    fn test_detect_method_call_with_field_receiver() {
        let source = r#"fn test() { ctx.accounts.process(); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "process"
            && matches!(&c.call_type, CallType::MethodCall { receiver: Some(r) } if r == "ctx.accounts")));
    }

    /// Names listed in `FILTERED_NAMES` are dropped while other calls are kept.
    #[test]
    fn test_filters_common_names() {
        let source = r#"fn test() { Ok(value); Some(x); vec![1,2]; foo(1); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(!calls.iter().any(|c| c.function_name == "Ok"));
        assert!(!calls.iter().any(|c| c.function_name == "Some"));
        assert!(calls.iter().any(|c| c.function_name == "foo"));
    }

    /// Repeated calls to the same function are reported once.
    #[test]
    fn test_deduplicates() {
        let source = r#"fn test() { foo(1); foo(2); foo(3); }"#;
        let calls = detect_function_calls(source).unwrap();
        let foo_count = calls.iter().filter(|c| c.function_name == "foo").count();
        assert_eq!(foo_count, 1);
    }

    /// A call used as an argument of another call is detected as well.
    #[test]
    fn test_nested_calls() {
        let source = r#"fn test() { outer(inner(x)); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "outer"));
        assert!(calls.iter().any(|c| c.function_name == "inner"));
    }

    /// Every link of a method chain is detected.
    #[test]
    fn test_chained_methods() {
        let source = r#"fn test() { x.foo().bar().baz(); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "foo"));
        assert!(calls.iter().any(|c| c.function_name == "bar"));
        assert!(calls.iter().any(|c| c.function_name == "baz"));
    }

    /// Calls inside closure bodies are detected, including closures passed to
    /// filtered methods such as `map`.
    #[test]
    fn test_closure_calls() {
        let source = r#"fn test() { let f = |x| compute(x); items.iter().map(|i| transform(i)); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "compute"));
        assert!(calls.iter().any(|c| c.function_name == "transform"));
    }

    /// Source that is only a function body (no `fn` header) is accepted via
    /// the wrapper fallback.
    #[test]
    fn test_body_only_fallback() {
        let source = r#"let x = foo(1); bar(x);"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "foo"));
        assert!(calls.iter().any(|c| c.function_name == "bar"));
    }

    /// For paths longer than two segments, the qualifier is the segment
    /// immediately before the function name.
    #[test]
    fn test_long_path_static() {
        let source = r#"fn test() { module::SubModule::create(arg); }"#;
        let calls = detect_function_calls(source).unwrap();
        assert!(calls.iter().any(|c| c.function_name == "create"
            && c.call_type == CallType::StaticMethod {
                type_name: "SubModule".to_string()
            }));
    }
}