use std::collections::{HashMap, HashSet};
use syn::visit::Visit;
use crate::batbelt::metadata::functions_source_code_metadata::FunctionSourceCodeMetadata;
use crate::batbelt::metadata::trait_metadata::TraitMetadata;
use crate::batbelt::metadata::MetadataId;
use crate::batbelt::parser::file_scope::FileScope;
use crate::batbelt::parser::type_resolver::{ResolvedType, TypeResolver};
/// Outcome of resolving a single call site.
#[derive(Debug, Clone, PartialEq)]
pub enum Resolution {
    /// The call targets a known function; carries that function's metadata id.
    Internal(MetadataId),
    /// The call targets something outside the known metadata; carries the name or
    /// `Type::method` signature that was looked up.
    External(String),
    /// The target could not be determined; carries the best available name for the call.
    Unresolved(String),
}
/// One call made by the analysed function, paired with where it resolved to.
#[derive(Debug, Clone)]
pub struct ResolvedCall {
    /// Called name: the last path segment of the callee or the method identifier.
    pub function_name: String,
    /// Resolution outcome for this call.
    pub resolution: Resolution,
}
/// Method names (`receiver.name(..)`) that are never recorded as calls: standard-library
/// plumbing (unwrapping, conversions, iterators, collections, strings, `Option`/`Result`
/// combinators, integer arithmetic), error-reporting helpers, and account accessors.
const FILTERED_METHOD_NAMES: &[&str] = &[
    // Unwrapping and owned copies.
    "unwrap", "expect", "clone", "to_string", "to_owned",
    // Iterator construction, adaptors and consumers.
    "iter", "into_iter", "map", "filter", "collect", "fold", "for_each", "find", "any", "all",
    // Collection access and mutation.
    "push", "pop", "len", "is_empty", "contains", "get", "insert", "remove", "extend",
    // Option / Result combinators.
    "ok_or", "ok_or_else", "map_err", "and_then", "or_else", "unwrap_or", "unwrap_or_else",
    // Borrowing, conversions and defaults.
    "as_ref", "as_mut", "borrow", "borrow_mut",
    "into", "from", "try_into", "try_from",
    "default", "to_vec", "as_slice", "as_str",
    // Option / Result state checks.
    "is_some", "is_none", "is_ok", "is_err",
    // String handling.
    "trim", "trim_start", "trim_end", "split", "join", "replace", "starts_with", "ends_with",
    "lines", "chars", "bytes",
    // More iterator adaptors.
    "next", "enumerate", "skip", "take", "zip", "chain", "flat_map", "flatten",
    "filter_map", "position", "count",
    // Sorting and deduplication.
    "sort", "sort_by", "sort_by_key", "dedup",
    // Error-reporting helpers.
    "change_context", "attach_printable", "report", "into_report",
    // Account accessors, loaders and (de)serialization.
    "key", "to_account_info", "to_accounts", "to_account_infos", "to_account_metas",
    "load", "load_mut", "load_init", "reload", "try_borrow_data", "try_borrow_mut_data",
    "lamports", "data", "data_len", "data_is_empty", "owner", "executable", "rent_epoch",
    "program_id", "signer", "signers", "cpi_accounts", "cpi_program",
    "set_inner", "exit", "try_serialize", "try_deserialize",
    "try_deserialize_unchecked", "try_accounts",
    // Checked / saturating / wrapping / overflowing integer arithmetic.
    "checked_add", "checked_sub", "checked_mul", "checked_div", "checked_rem",
    "saturating_add", "saturating_sub", "saturating_mul", "saturating_div",
    "wrapping_add", "wrapping_sub", "wrapping_mul", "wrapping_div",
    "overflowing_add", "overflowing_sub", "overflowing_mul",
    // Byte and address helpers.
    "to_bytes", "as_array", "find_program_address", "create_program_address",
    "to_le_bytes", "to_be_bytes", "from_le_bytes", "from_be_bytes",
];
/// Names ignored when collecting `name(..)` and `Type::name(..)` calls: enum constructors,
/// std/Anchor macro names, smart-pointer and collection names, CPI entry points and SDK
/// associated functions. The callee's last path segment is compared against this list.
const FILTERED_FUNCTION_NAMES: &[&str] = &[
    "Ok", "Some", "Err", "None", "vec", "format", "println", "eprintln", "print", "eprint",
    "panic", "todo", "unimplemented", "unreachable", "assert", "assert_eq", "assert_ne",
    "debug_assert", "debug_assert_eq", "debug_assert_ne", "write", "writeln", "log",
    "cfg", "include", "include_str", "include_bytes", "env", "option_env",
    "concat", "stringify", "file", "line", "column", "module_path",
    "Box", "Vec", "String", "Arc", "Rc", "Mutex", "RefCell",
    "msg", "emit", "emit_cpi", "require", "require_eq", "require_neq", "require_keys_eq",
    "require_keys_neq", "require_gt", "require_gte", "require_lt", "require_lte",
    "invoke", "invoke_signed", "system_program",
    // Associated functions that are called as `Type::name(..)` (e.g.
    // `Pubkey::find_program_address`, `u64::from_le_bytes`). Static calls are matched
    // against this list, not against FILTERED_METHOD_NAMES.
    "find_program_address", "create_program_address", "from_le_bytes", "from_be_bytes",
];
/// Resolves the calls made by one function to known functions (`Internal`), named outside
/// targets (`External`) or `Unresolved`.
pub struct CallResolver<'a> {
    /// Scope of the file containing the analysed function (import resolution and its path).
    file_scope: &'a FileScope,
    /// Known `impl` blocks, searched for `Type::method`, `self.method()` and typed-receiver calls.
    trait_metadata: &'a [TraitMetadata],
    /// Known functions, searched by name for bare `name(..)` calls.
    function_metadata: &'a [FunctionSourceCodeMetadata],
    /// Metadata id of the function whose calls are being resolved.
    current_function_id: &'a str,
}
impl<'a> CallResolver<'a> {
pub fn new(
file_scope: &'a FileScope,
trait_metadata: &'a [TraitMetadata],
function_metadata: &'a [FunctionSourceCodeMetadata],
current_function_id: &'a str,
) -> Self {
Self {
file_scope,
trait_metadata,
function_metadata,
current_function_id,
}
}
pub fn resolve_function(&self, item_fn: &syn::ItemFn) -> Vec<ResolvedCall> {
let type_resolver = TypeResolver::new(self.file_scope);
let mut param_types: HashMap<String, ResolvedType> = HashMap::new();
for input in &item_fn.sig.inputs {
if let syn::FnArg::Typed(pat_type) = input {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
let name = pat_ident.ident.to_string();
let resolved = type_resolver.resolve(&pat_type.ty);
param_types.insert(name, resolved);
}
}
}
let mut context_accounts_types: HashMap<String, ResolvedType> = HashMap::new();
for input in &item_fn.sig.inputs {
if let syn::FnArg::Typed(pat_type) = input {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
let name = pat_ident.ident.to_string();
if let Some(inner) = type_resolver.resolve_context_inner(&pat_type.ty) {
context_accounts_types.insert(name, inner);
}
}
}
}
let mut visitor = CallVisitor {
calls: Vec::new(),
seen: HashSet::new(),
};
visitor.visit_block(&item_fn.block);
let current_fn_name = item_fn.sig.ident.to_string();
visitor
.calls
.into_iter()
.filter(|raw| raw.function_name != current_fn_name)
.map(|raw| {
let resolution = self.resolve_call(&raw, ¶m_types, &context_accounts_types);
ResolvedCall {
function_name: raw.function_name,
resolution,
}
})
.collect()
}
fn resolve_call(
&self,
raw: &RawCall,
param_types: &HashMap<String, ResolvedType>,
context_accounts_types: &HashMap<String, ResolvedType>,
) -> Resolution {
match &raw.call_type {
RawCallType::FreeFunction => self.resolve_free_function(&raw.function_name),
RawCallType::StaticMethod { type_name } => {
self.resolve_static_method(type_name, &raw.function_name)
}
RawCallType::MethodCall { receiver } => {
self.resolve_method_call(
receiver.as_deref(),
&raw.function_name,
param_types,
context_accounts_types,
)
}
}
}
fn resolve_free_function(&self, name: &str) -> Resolution {
let candidates_paths = self.file_scope.resolve_name_candidates(name);
let matching: Vec<_> = self
.function_metadata
.iter()
.filter(|f| f.name == name)
.collect();
if matching.is_empty() {
return Resolution::External(name.to_string());
}
if matching.len() == 1 {
return Resolution::Internal(matching[0].metadata_id.clone());
}
for candidate_path in &candidates_paths {
if candidate_path.starts_with("self::") {
if let Some(f) = matching
.iter()
.find(|f| f.path == self.file_scope.path)
{
return Resolution::Internal(f.metadata_id.clone());
}
continue;
}
if let Some(fragment) = path_to_fragment(candidate_path, name) {
if let Some(f) = matching.iter().find(|f| f.path.contains(&fragment)) {
return Resolution::Internal(f.metadata_id.clone());
}
}
}
log::warn!(
"CallResolver: unresolved ambiguous free function '{}' from file '{}' ({} candidates)",
name,
self.file_scope.path,
matching.len()
);
Resolution::Unresolved(name.to_string())
}
fn resolve_static_method(&self, type_name: &str, method: &str) -> Resolution {
let _full_type_path = self.file_scope.resolve_name(type_name);
let trait_signature = format!("{}::{}", type_name, method);
let matching_impls: Vec<_> = self
.trait_metadata
.iter()
.filter(|tm| tm.impl_to == type_name)
.collect();
for tm in &matching_impls {
for impl_fn in &tm.impl_functions {
if impl_fn.trait_signature == trait_signature {
return Resolution::Internal(impl_fn.function_source_code_metadata_id.clone());
}
}
}
Resolution::External(trait_signature)
}
fn resolve_method_call(
&self,
receiver: Option<&str>,
method: &str,
param_types: &HashMap<String, ResolvedType>,
context_accounts_types: &HashMap<String, ResolvedType>,
) -> Resolution {
let Some(receiver_str) = receiver else {
return Resolution::Unresolved(method.to_string());
};
if receiver_str == "self"
|| receiver_str == "self.mut"
|| receiver_str.starts_with("self.")
{
return self.resolve_self_method(method);
}
if let Some(dot_idx) = receiver_str.find('.') {
let root = &receiver_str[..dot_idx];
let rest = &receiver_str[dot_idx + 1..];
if rest == "accounts" {
if let Some(accounts_type) = context_accounts_types.get(root) {
return self.resolve_method_on_type(accounts_type, method);
}
}
}
if let Some(param_type) = param_types.get(receiver_str) {
return self.resolve_method_on_type(param_type, method);
}
Resolution::Unresolved(method.to_string())
}
fn resolve_self_method(&self, method: &str) -> Resolution {
let self_impl_type = self.trait_metadata.iter().find_map(|tm| {
let contains = tm
.impl_functions
.iter()
.any(|f| f.function_source_code_metadata_id == self.current_function_id);
if contains {
Some(tm.impl_to.clone())
} else {
None
}
});
let Some(impl_type) = self_impl_type else {
return Resolution::Unresolved(format!("self.{}", method));
};
let trait_signature = format!("{}::{}", impl_type, method);
for tm in self.trait_metadata {
if tm.impl_to == impl_type {
for f in &tm.impl_functions {
if f.trait_signature == trait_signature {
return Resolution::Internal(f.function_source_code_metadata_id.clone());
}
}
}
}
Resolution::Unresolved(trait_signature)
}
fn resolve_method_on_type(&self, ty: &ResolvedType, method: &str) -> Resolution {
let Some(type_name) = ty.type_name() else {
return Resolution::Unresolved(method.to_string());
};
let trait_signature = format!("{}::{}", type_name, method);
for tm in self.trait_metadata {
if tm.impl_to == type_name {
for f in &tm.impl_functions {
if f.trait_signature == trait_signature {
return Resolution::Internal(f.function_source_code_metadata_id.clone());
}
}
}
}
Resolution::External(trait_signature)
}
}
/// Turns an import path for `function_name` into a `/`-joined fragment that is searched for
/// in function file paths. For example, `crate::instructions::admin::initialize::process`
/// with `process` becomes `instructions/admin/initialize`.
///
/// `crate`, `super` and `self` segments are dropped, as is the trailing segment when it is
/// the function's own name. Returns `None` when nothing is left.
fn path_to_fragment(import_path: &str, function_name: &str) -> Option<String> {
    let mut segments: Vec<&str> = import_path.split("::").collect();
    // Only the trailing segment names the function. A module sharing the function's name
    // (`crate::instructions::deposit::deposit`) is part of the location and must be kept.
    if segments.last() == Some(&function_name) {
        segments.pop();
    }
    segments.retain(|s| !matches!(*s, "crate" | "super" | "self"));
    if segments.is_empty() {
        None
    } else {
        Some(segments.join("/"))
    }
}
/// A call site as seen syntactically, before resolution.
#[derive(Debug, Clone)]
struct RawCall {
    /// Called name: the last segment of the callee path or the method identifier.
    function_name: String,
    /// Syntactic shape of the call.
    call_type: RawCallType,
}
/// Syntactic shape of a call site.
#[derive(Debug, Clone)]
enum RawCallType {
    /// `name(..)`, a single-segment callee path.
    FreeFunction,
    /// `Qualifier::name(..)`; `type_name` is the segment right before the name, which may be
    /// a type, `Self` or a module.
    StaticMethod { type_name: String },
    /// `receiver.name(..)`; `receiver` is the rendered receiver, or `None` when its shape is
    /// not understood.
    MethodCall { receiver: Option<String> },
}
/// Syntax visitor that collects the calls made inside a function body.
struct CallVisitor {
    /// Recorded calls, in first-seen order.
    calls: Vec<RawCall>,
    /// Dedup keys of the calls already recorded.
    seen: HashSet<String>,
}
impl CallVisitor {
    /// Appends `call` to `calls` unless it is filtered as noise or was already recorded.
    ///
    /// Filtering rules:
    /// * free functions: by name, against `FILTERED_FUNCTION_NAMES`;
    /// * static calls: by name against the same list, or by qualifier against
    ///   `FILTERED_TYPE_QUALIFIERS`;
    /// * method calls: by name, against `FILTERED_METHOD_NAMES`.
    ///
    /// Duplicates are detected by call kind + qualifier/receiver + name, so the same method
    /// called on two different receivers is recorded twice.
    fn record(&mut self, call: RawCall) {
        /// Path qualifiers whose associated functions are std/SDK plumbing (`Vec::new`,
        /// `String::from`, `system_program::transfer`, ...). These names are also listed in
        /// `FILTERED_FUNCTION_NAMES`, but that list is only compared against the last path
        /// segment, which for `Qualifier::name(..)` is never the qualifier.
        const FILTERED_TYPE_QUALIFIERS: &[&str] = &[
            "Box",
            "Vec",
            "String",
            "Arc",
            "Rc",
            "Mutex",
            "RefCell",
            "system_program",
        ];

        let name = call.function_name.as_str();
        let filtered = match &call.call_type {
            RawCallType::FreeFunction => FILTERED_FUNCTION_NAMES.contains(&name),
            RawCallType::StaticMethod { type_name } => {
                FILTERED_FUNCTION_NAMES.contains(&name)
                    || FILTERED_TYPE_QUALIFIERS.contains(&type_name.as_str())
            }
            RawCallType::MethodCall { .. } => FILTERED_METHOD_NAMES.contains(&name),
        };
        if filtered {
            return;
        }

        let key = match &call.call_type {
            RawCallType::FreeFunction => format!("free::{}", call.function_name),
            RawCallType::StaticMethod { type_name } => {
                format!("static::{}::{}", type_name, call.function_name)
            }
            RawCallType::MethodCall { receiver } => format!(
                "method::{}::{}",
                receiver.as_deref().unwrap_or("?"),
                call.function_name
            ),
        };
        if self.seen.insert(key) {
            self.calls.push(call);
        }
    }
}
impl<'ast> Visit<'ast> for CallVisitor {
    /// Records a call whose callee is a path, then walks into the call's children.
    ///
    /// A one-segment path (`name(..)`) is a free-function call. A longer path is a static
    /// call qualified by its second-to-last segment (`a::Type::name(..)` has qualifier
    /// `Type`). Callees that are not paths are not recorded.
    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
        if let syn::Expr::Path(callee) = &*node.func {
            // Walk the path from the end: the callee name first, then its qualifier.
            let mut idents = callee
                .path
                .segments
                .iter()
                .rev()
                .map(|segment| segment.ident.to_string());
            match (idents.next(), idents.next()) {
                (Some(function_name), None) => self.record(RawCall {
                    function_name,
                    call_type: RawCallType::FreeFunction,
                }),
                (Some(function_name), Some(type_name)) => self.record(RawCall {
                    function_name,
                    call_type: RawCallType::StaticMethod { type_name },
                }),
                (None, _) => {}
            }
        }
        syn::visit::visit_expr_call(self, node);
    }

    /// Records `receiver.method(..)` with its rendered receiver, then walks into the call's
    /// children (receiver included).
    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
        self.record(RawCall {
            function_name: node.method.to_string(),
            call_type: RawCallType::MethodCall {
                receiver: receiver_to_string(&node.receiver),
            },
        });
        syn::visit::visit_expr_method_call(self, node);
    }
}
/// Renders a method-call receiver as a string, or `None` for receiver shapes it does not
/// understand.
///
/// Paths are joined with `::` (`self`, `foo`, `a::b`). Field accesses are joined with `.`
/// (`ctx.accounts`, `self.state`; tuple fields use their index). Nested method calls render
/// as `base.method()`. Parentheses and references are looked through. If the base of a field
/// or method call cannot be rendered, only the member or `method()` part is returned.
fn receiver_to_string(expr: &syn::Expr) -> Option<String> {
    match expr {
        // Transparent wrappers: render whatever they wrap.
        syn::Expr::Paren(paren) => receiver_to_string(&paren.expr),
        syn::Expr::Reference(reference) => receiver_to_string(&reference.expr),
        syn::Expr::Path(path_expr) => {
            let mut rendered = String::new();
            for (index, segment) in path_expr.path.segments.iter().enumerate() {
                if index > 0 {
                    rendered.push_str("::");
                }
                rendered.push_str(&segment.ident.to_string());
            }
            Some(rendered)
        }
        syn::Expr::Field(field_expr) => {
            let member = match &field_expr.member {
                syn::Member::Named(ident) => ident.to_string(),
                syn::Member::Unnamed(index) => index.index.to_string(),
            };
            Some(match receiver_to_string(&field_expr.base) {
                Some(base) => format!("{}.{}", base, member),
                None => member,
            })
        }
        syn::Expr::MethodCall(inner_call) => {
            let call = format!("{}()", inner_call.method);
            Some(match receiver_to_string(&inner_call.receiver) {
                Some(base) => format!("{}.{}", base, call),
                None => call,
            })
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A visitor that has not recorded anything yet.
    fn empty_visitor() -> CallVisitor {
        CallVisitor {
            calls: Vec::new(),
            seen: HashSet::new(),
        }
    }

    /// A bare `name(..)` call.
    fn free_call(name: &str) -> RawCall {
        RawCall {
            function_name: name.to_string(),
            call_type: RawCallType::FreeFunction,
        }
    }

    /// A `receiver.name(..)` call.
    fn method_call(receiver: &str, name: &str) -> RawCall {
        RawCall {
            function_name: name.to_string(),
            call_type: RawCallType::MethodCall {
                receiver: Some(receiver.to_string()),
            },
        }
    }

    /// Names of the calls the visitor kept, in recorded order.
    fn recorded_names(visitor: &CallVisitor) -> Vec<&str> {
        visitor
            .calls
            .iter()
            .map(|call| call.function_name.as_str())
            .collect()
    }

    /// Parses `src` as an expression and renders it with `receiver_to_string`.
    fn render_receiver(src: &str) -> Option<String> {
        let expr: syn::Expr = syn::parse_str(src).unwrap();
        receiver_to_string(&expr)
    }

    #[test]
    fn test_path_fragment_simple() {
        let fragment =
            path_to_fragment("crate::instructions::admin::initialize::process", "process");
        assert_eq!(fragment.as_deref(), Some("instructions/admin/initialize"));
    }

    #[test]
    fn test_path_fragment_external() {
        let fragment = path_to_fragment("anchor_lang::prelude::Pubkey", "Pubkey");
        assert_eq!(fragment.as_deref(), Some("anchor_lang/prelude"));
    }

    #[test]
    fn test_path_fragment_empty() {
        assert!(path_to_fragment("crate::foo", "foo").is_none());
    }

    #[test]
    fn test_filter_filtered_function_name() {
        let mut visitor = empty_visitor();
        for name in ["Ok", "real_fn"] {
            visitor.record(free_call(name));
        }
        assert_eq!(recorded_names(&visitor), ["real_fn"]);
    }

    #[test]
    fn test_filter_filtered_method_name() {
        let mut visitor = empty_visitor();
        visitor.record(method_call("account", "key"));
        visitor.record(method_call("ctx.accounts", "process"));
        assert_eq!(recorded_names(&visitor), ["process"]);
    }

    #[test]
    fn test_deduplication() {
        let mut visitor = empty_visitor();
        for _ in 0..3 {
            visitor.record(free_call("foo"));
        }
        assert_eq!(visitor.calls.len(), 1);
    }

    #[test]
    fn test_receiver_to_string_path() {
        assert_eq!(render_receiver("foo").as_deref(), Some("foo"));
    }

    #[test]
    fn test_receiver_to_string_nested_field() {
        assert_eq!(render_receiver("ctx.accounts").as_deref(), Some("ctx.accounts"));
    }

    #[test]
    fn test_receiver_to_string_self() {
        assert_eq!(render_receiver("self").as_deref(), Some("self"));
    }

    #[test]
    fn test_receiver_to_string_self_field() {
        assert_eq!(render_receiver("self.state").as_deref(), Some("self.state"));
    }
}