use std::collections::{HashMap, HashSet};
use syn::visit::Visit;
use crate::batbelt::metadata::functions_source_code_metadata::FunctionSourceCodeMetadata;
use crate::batbelt::metadata::trait_metadata::TraitMetadata;
use crate::batbelt::metadata::MetadataId;
use crate::batbelt::parser::file_scope::FileScope;
use crate::batbelt::parser::type_resolver::{ResolvedType, TypeResolver};
/// Outcome of trying to map one call site onto a known function definition.
#[derive(Debug, Clone, PartialEq)]
pub enum Resolution {
    /// The call matched an entry in the supplied metadata; carries that
    /// entry's metadata id.
    Internal(MetadataId),
    /// No entry in the supplied metadata matched; carries the name (or
    /// `Type::method` signature) that was looked up.
    External(String),
    /// The target could not be determined (for example an ambiguous name or a
    /// receiver whose type is not known); carries the name that was looked up.
    Unresolved(String),
}
/// A call found in a function body together with the result of resolving it.
#[derive(Debug, Clone)]
pub struct ResolvedCall {
    /// Name of the called function or method (last path segment / method ident).
    pub function_name: String,
    /// Where the call was resolved to, if anywhere.
    pub resolution: Resolution,
}
/// Method names that are never recorded as calls.
///
/// `CallVisitor::record` drops every method call whose method identifier is
/// listed here. Matching is by bare name only; the receiver is not inspected.
///
/// The group labels below are a reading aid based on where these names
/// commonly come from; they carry no meaning for the matching itself.
#[rustfmt::skip]
const FILTERED_METHOD_NAMES: &[&str] = &[
    // Unwrapping / cloning / owned conversions.
    "unwrap", "expect", "clone", "to_string", "to_owned",
    // Iteration basics.
    "iter", "into_iter", "map", "filter", "collect", "fold", "for_each", "find", "any", "all",
    // Collection access and mutation.
    "push", "pop", "len", "is_empty", "contains", "get", "insert", "remove", "extend",
    // Option / Result combinators.
    "ok_or", "ok_or_else", "map_err", "and_then", "or_else", "unwrap_or", "unwrap_or_else",
    // Borrowing helpers.
    "as_ref", "as_mut", "borrow", "borrow_mut",
    // Conversion traits and defaults.
    "into", "from", "try_into", "try_from", "default",
    // Slice / string views.
    "to_vec", "as_slice", "as_str",
    // Option / Result predicates.
    "is_some", "is_none", "is_ok", "is_err",
    // String handling.
    "trim", "trim_start", "trim_end", "split", "join", "replace", "starts_with", "ends_with",
    "lines", "chars", "bytes",
    // Iterator adaptors and consumers.
    "next", "enumerate", "skip", "take", "zip", "chain", "flat_map", "flatten", "filter_map",
    "position", "count",
    // Sorting and deduplication.
    "sort", "sort_by", "sort_by_key", "dedup",
    // Error-report helpers.
    "change_context", "attach_printable", "report", "into_report",
    // Account conversions (names as used by Solana/Anchor-style code — presumably).
    "key", "to_account_info", "to_accounts", "to_account_infos", "to_account_metas",
    // Account loading.
    "load", "load_mut", "load_init", "reload",
    // Raw account data borrows.
    "try_borrow_data", "try_borrow_mut_data",
    // Account field accessors.
    "lamports", "data", "data_len", "data_is_empty", "owner", "executable", "rent_epoch",
    // Program / signer / CPI accessors.
    "program_id", "signer", "signers", "cpi_accounts", "cpi_program",
    // Account lifecycle.
    "set_inner", "exit",
    // (De)serialization.
    "try_serialize", "try_deserialize", "try_deserialize_unchecked", "try_accounts",
    // Checked arithmetic.
    "checked_add", "checked_sub", "checked_mul", "checked_div", "checked_rem",
    // Saturating arithmetic.
    "saturating_add", "saturating_sub", "saturating_mul", "saturating_div",
    // Wrapping arithmetic.
    "wrapping_add", "wrapping_sub", "wrapping_mul", "wrapping_div",
    // Overflowing arithmetic.
    "overflowing_add", "overflowing_sub", "overflowing_mul",
    // Key bytes and program-address derivation.
    "to_bytes", "as_array", "find_program_address", "create_program_address",
    // Integer byte-order conversions.
    "to_le_bytes", "to_be_bytes", "from_le_bytes", "from_be_bytes",
];
/// Names that are never recorded for free-function or path-qualified calls.
///
/// `CallVisitor::record` compares the called name of such calls against this
/// list and drops the call on a match.
///
/// The group labels below are a reading aid based on where these names
/// commonly come from; they carry no meaning for the matching itself.
#[rustfmt::skip]
const FILTERED_FUNCTION_NAMES: &[&str] = &[
    // Option / Result constructors.
    "Ok", "Some", "Err", "None",
    // Formatting and printing.
    "vec", "format", "println", "eprintln", "print", "eprint",
    // Panicking placeholders.
    "panic", "todo", "unimplemented", "unreachable",
    // Assertions.
    "assert", "assert_eq", "assert_ne", "debug_assert", "debug_assert_eq", "debug_assert_ne",
    // Writing and logging.
    "write", "writeln", "log",
    // Compile-time built-ins.
    "cfg", "include", "include_str", "include_bytes", "env", "option_env", "concat",
    "stringify", "file", "line", "column", "module_path",
    // Standard container / smart-pointer type names.
    "Box", "Vec", "String", "Arc", "Rc", "Mutex", "RefCell",
    // Program logging and events (names as used by Anchor-style code — presumably).
    "msg", "emit", "emit_cpi",
    // `require*` checks.
    "require", "require_eq", "require_neq", "require_keys_eq", "require_keys_neq",
    "require_gt", "require_gte", "require_lt", "require_lte",
    // Cross-program invocation helpers.
    "invoke", "invoke_signed", "system_program",
];
/// Resolves the calls made inside one function against project metadata.
///
/// All inputs are borrowed for `'a`; the resolver owns nothing.
pub struct CallResolver<'a> {
    /// Scope of the file containing the function being analysed; consulted for
    /// name resolution and for the file's own path.
    file_scope: &'a FileScope,
    /// Trait/impl metadata searched when resolving `Type::method` and
    /// receiver-based method calls.
    trait_metadata: &'a [TraitMetadata],
    /// Function metadata searched by name and path.
    function_metadata: &'a [FunctionSourceCodeMetadata],
    /// Metadata id of the function being analysed; used to find the impl that
    /// contains it when resolving calls on `self`.
    current_function_id: &'a str,
}
impl<'a> CallResolver<'a> {
    /// Builds a resolver over borrowed project metadata.
    ///
    /// * `file_scope` - scope of the file that contains the analysed function.
    /// * `trait_metadata` - impl metadata searched for `Type::method` targets.
    /// * `function_metadata` - function metadata searched by name and path.
    /// * `current_function_id` - metadata id of the analysed function; used to
    ///   locate the impl it belongs to when resolving calls on `self`.
    pub fn new(
        file_scope: &'a FileScope,
        trait_metadata: &'a [TraitMetadata],
        function_metadata: &'a [FunctionSourceCodeMetadata],
        current_function_id: &'a str,
    ) -> Self {
        Self {
            file_scope,
            trait_metadata,
            function_metadata,
            current_function_id,
        }
    }

    /// Collects the calls made in `item_fn`'s body and resolves each of them.
    ///
    /// Typed parameters bound to a plain identifier are recorded twice: once
    /// with their resolved type, and once more (only when
    /// `TypeResolver::resolve_context_inner` yields something) with the inner
    /// type used for `<param>.accounts` receivers. Calls whose name equals the
    /// name of `item_fn` itself are dropped before resolution.
    pub fn resolve_function(&self, item_fn: &syn::ItemFn) -> Vec<ResolvedCall> {
        let type_resolver = TypeResolver::new(self.file_scope);
        let mut param_types: HashMap<String, ResolvedType> = HashMap::new();
        let mut context_accounts_types: HashMap<String, ResolvedType> = HashMap::new();

        // One pass over the signature fills both lookup tables.
        for input in &item_fn.sig.inputs {
            let syn::FnArg::Typed(pat_type) = input else {
                continue;
            };
            let syn::Pat::Ident(pat_ident) = &*pat_type.pat else {
                continue;
            };
            let name = pat_ident.ident.to_string();
            let resolved = type_resolver.resolve(&pat_type.ty);
            if let Some(inner) = type_resolver.resolve_context_inner(&pat_type.ty) {
                context_accounts_types.insert(name.clone(), inner);
            }
            param_types.insert(name, resolved);
        }

        let mut visitor = CallVisitor {
            calls: Vec::new(),
            seen: HashSet::new(),
        };
        visitor.visit_block(&item_fn.block);

        let current_fn_name = item_fn.sig.ident.to_string();
        visitor
            .calls
            .into_iter()
            // Skip anything that shares the analysed function's own name.
            .filter(|raw| raw.function_name != current_fn_name)
            .map(|raw| {
                let resolution = self.resolve_call(&raw, &param_types, &context_accounts_types);
                ResolvedCall {
                    function_name: raw.function_name,
                    resolution,
                }
            })
            .collect()
    }

    /// Dispatches a raw call to the resolver matching its call shape.
    fn resolve_call(
        &self,
        raw: &RawCall,
        param_types: &HashMap<String, ResolvedType>,
        context_accounts_types: &HashMap<String, ResolvedType>,
    ) -> Resolution {
        match &raw.call_type {
            RawCallType::FreeFunction => self.resolve_free_function(&raw.function_name),
            RawCallType::StaticMethod { type_name } => {
                self.resolve_static_method(type_name, &raw.function_name)
            }
            RawCallType::MethodCall { receiver } => self.resolve_method_call(
                receiver.as_deref(),
                &raw.function_name,
                param_types,
                context_accounts_types,
            ),
        }
    }

    /// Resolves a call written as a bare identifier, e.g. `foo(...)`.
    ///
    /// * no function with that name in the metadata -> `External`
    /// * exactly one -> `Internal`
    /// * several -> disambiguated through the file scope's candidate paths;
    ///   if none of them selects a function the call is `Unresolved` and a
    ///   warning is logged.
    fn resolve_free_function(&self, name: &str) -> Resolution {
        let matching: Vec<_> = self
            .function_metadata
            .iter()
            .filter(|f| f.name == name)
            .collect();
        match matching.as_slice() {
            [] => return Resolution::External(name.to_string()),
            [only] => return Resolution::Internal(only.metadata_id.clone()),
            _ => {}
        }

        // Only the ambiguous case needs the candidate import paths.
        let candidates_paths = self.file_scope.resolve_name_candidates(name);
        for candidate_path in &candidates_paths {
            if candidate_path.starts_with("self::") {
                // `self::` candidates refer to the current file.
                if let Some(f) = matching.iter().find(|f| f.path == self.file_scope.path) {
                    return Resolution::Internal(f.metadata_id.clone());
                }
                continue;
            }
            if let Some(fragment) = path_to_fragment(candidate_path, name) {
                if let Some(f) = matching.iter().find(|f| f.path.contains(&fragment)) {
                    return Resolution::Internal(f.metadata_id.clone());
                }
            }
        }

        log::warn!(
            "CallResolver: unresolved ambiguous free function '{}' from file '{}' ({} candidates)",
            name,
            self.file_scope.path,
            matching.len()
        );
        Resolution::Unresolved(name.to_string())
    }

    /// Resolves a path-qualified call such as `qualifier::method(...)`.
    ///
    /// `type_name` is the path segment right before the called name, so it may
    /// name a module as well as a type. Lookup order:
    /// 1. functions named `method` whose path contains `/<type_name>/`
    ///    (accepted only when exactly one matches);
    /// 2. if step 1 found none: functions named `method` whose path contains
    ///    `/<type_name>.rs` (again only when exactly one matches);
    /// 3. impl metadata for `type_name` with signature `<type_name>::<method>`.
    ///
    /// Falls back to `External("<type_name>::<method>")`.
    fn resolve_static_method(&self, type_name: &str, method: &str) -> Resolution {
        let module_path_fragment = format!("/{}/", type_name);
        let module_matches: Vec<_> = self
            .function_metadata
            .iter()
            .filter(|f| f.name == method && f.path.contains(&module_path_fragment))
            .collect();
        if module_matches.len() == 1 {
            return Resolution::Internal(module_matches[0].metadata_id.clone());
        }
        if module_matches.is_empty() {
            let module_file_fragment = format!("/{}.rs", type_name);
            let file_matches: Vec<_> = self
                .function_metadata
                .iter()
                .filter(|f| f.name == method && f.path.contains(&module_file_fragment))
                .collect();
            if file_matches.len() == 1 {
                return Resolution::Internal(file_matches[0].metadata_id.clone());
            }
        }

        let trait_signature = format!("{}::{}", type_name, method);
        match self.lookup_impl_function(type_name, &trait_signature) {
            Some(found) => found,
            None => Resolution::External(trait_signature),
        }
    }

    /// Resolves `receiver.method(...)`.
    ///
    /// * no printable receiver -> `Unresolved`
    /// * receiver is `self` or starts with `self.` -> looked up on the impl
    ///   that contains the current function
    /// * receiver is `<param>.accounts` and `<param>` has a known context
    ///   inner type -> looked up on that type
    /// * receiver is a parameter name -> looked up on the parameter's type
    /// * anything else -> `Unresolved`
    fn resolve_method_call(
        &self,
        receiver: Option<&str>,
        method: &str,
        param_types: &HashMap<String, ResolvedType>,
        context_accounts_types: &HashMap<String, ResolvedType>,
    ) -> Resolution {
        let Some(receiver_str) = receiver else {
            return Resolution::Unresolved(method.to_string());
        };

        // NOTE(review): `self.<field>` receivers are resolved against the impl
        // type of the current function, not against the field's type.
        if receiver_str == "self" || receiver_str.starts_with("self.") {
            return self.resolve_self_method(method);
        }

        // Split at the first '.': only the exact form `<root>.accounts` counts.
        if let Some((root, "accounts")) = receiver_str.split_once('.') {
            if let Some(accounts_type) = context_accounts_types.get(root) {
                return self.resolve_method_on_type(accounts_type, method);
            }
        }

        if let Some(param_type) = param_types.get(receiver_str) {
            return self.resolve_method_on_type(param_type, method);
        }
        Resolution::Unresolved(method.to_string())
    }

    /// Resolves a method called on `self`.
    ///
    /// The impl type is the `impl_to` of the first impl metadata entry that
    /// lists the current function. Returns `Unresolved("self.<method>")` when
    /// no such entry exists and `Unresolved("<Type>::<method>")` when the type
    /// is known but has no function with that signature.
    fn resolve_self_method(&self, method: &str) -> Resolution {
        let enclosing_impl = self.trait_metadata.iter().find(|tm| {
            tm.impl_functions
                .iter()
                .any(|f| f.function_source_code_metadata_id == self.current_function_id)
        });
        let Some(enclosing_impl) = enclosing_impl else {
            return Resolution::Unresolved(format!("self.{}", method));
        };

        let trait_signature = format!("{}::{}", enclosing_impl.impl_to, method);
        match self.lookup_impl_function(&enclosing_impl.impl_to, &trait_signature) {
            Some(found) => found,
            None => Resolution::Unresolved(trait_signature),
        }
    }

    /// Resolves `method` on an already resolved type.
    ///
    /// Returns `Unresolved(method)` when the type has no name, `Internal` when
    /// an impl of that type lists `<Type>::<method>`, and
    /// `External("<Type>::<method>")` otherwise.
    fn resolve_method_on_type(&self, ty: &ResolvedType, method: &str) -> Resolution {
        let Some(type_name) = ty.type_name() else {
            return Resolution::Unresolved(method.to_string());
        };
        let trait_signature = format!("{}::{}", type_name, method);
        match self.lookup_impl_function(&type_name, &trait_signature) {
            Some(found) => found,
            None => Resolution::External(trait_signature),
        }
    }

    /// Finds the first function with signature `trait_signature` among all
    /// impl metadata entries whose `impl_to` equals `impl_type`.
    fn lookup_impl_function(&self, impl_type: &str, trait_signature: &str) -> Option<Resolution> {
        self.trait_metadata
            .iter()
            .filter(|tm| tm.impl_to == impl_type)
            .flat_map(|tm| tm.impl_functions.iter())
            .find(|f| f.trait_signature == trait_signature)
            .map(|f| Resolution::Internal(f.function_source_code_metadata_id.clone()))
    }
}
/// Turns an import path into a `/`-separated fragment that can be searched for
/// inside a source file path.
///
/// The `crate`, `super` and `self` segments and any empty segments (such as
/// the one produced by a leading `::`) are dropped. A trailing segment equal
/// to `function_name` is dropped as well, because it names the item rather
/// than a module. Only that trailing segment is removed, so a module that
/// happens to share the function's name stays in the fragment.
///
/// Returns `None` when nothing is left.
fn path_to_fragment(import_path: &str, function_name: &str) -> Option<String> {
    let mut segments: Vec<&str> = import_path
        .split("::")
        .filter(|segment| !segment.is_empty() && !matches!(*segment, "crate" | "super" | "self"))
        .collect();

    // Strip the item name itself, but keep same-named modules further up.
    if segments.last() == Some(&function_name) {
        segments.pop();
    }

    if segments.is_empty() {
        None
    } else {
        Some(segments.join("/"))
    }
}
/// A call site as found in the syntax tree, before any resolution.
#[derive(Debug, Clone)]
struct RawCall {
    /// Called name: the last path segment or the method identifier.
    function_name: String,
    /// Syntactic shape of the call, with the data needed to resolve it.
    call_type: RawCallType,
}
/// Syntactic shape of a call site.
#[derive(Debug, Clone)]
enum RawCallType {
    /// Call through a single-segment path, e.g. `foo(...)`.
    FreeFunction,
    /// Call through a path with two or more segments, e.g. `a::b(...)`;
    /// `type_name` is the segment immediately before the called name.
    StaticMethod { type_name: String },
    /// Method-call syntax, e.g. `x.f(...)`; `receiver` is the receiver
    /// expression rendered as text, or `None` when it could not be rendered.
    MethodCall { receiver: Option<String> },
}
/// Syntax-tree visitor that accumulates the distinct calls it encounters.
struct CallVisitor {
    /// Recorded calls, in the order they were first seen.
    calls: Vec<RawCall>,
    /// Deduplication keys of the calls already stored in `calls`.
    seen: HashSet<String>,
}
impl CallVisitor {
    /// Stores `call` unless it is filtered out or was already recorded.
    ///
    /// Filtering:
    /// * free-function calls are dropped when their name is listed in
    ///   `FILTERED_FUNCTION_NAMES`;
    /// * path-qualified calls are dropped when either the called name or the
    ///   qualifier segment is listed there. The qualifier check is what makes
    ///   type and module entries in that list (which can never be the called
    ///   name of a call) take effect, e.g. for `Box::new(...)`;
    /// * method calls are dropped when their name is listed in
    ///   `FILTERED_METHOD_NAMES`.
    ///
    /// Deduplication is per call shape: the key combines the shape, the
    /// qualifier or receiver text (`?` for an unknown receiver) and the name.
    fn record(&mut self, call: RawCall) {
        let name = call.function_name.as_str();
        let filtered = match &call.call_type {
            RawCallType::FreeFunction => FILTERED_FUNCTION_NAMES.contains(&name),
            RawCallType::StaticMethod { type_name } => {
                FILTERED_FUNCTION_NAMES.contains(&name)
                    || FILTERED_FUNCTION_NAMES.contains(&type_name.as_str())
            }
            RawCallType::MethodCall { .. } => FILTERED_METHOD_NAMES.contains(&name),
        };
        if filtered {
            return;
        }

        let key = match &call.call_type {
            RawCallType::FreeFunction => format!("free::{}", call.function_name),
            RawCallType::StaticMethod { type_name } => {
                format!("static::{}::{}", type_name, call.function_name)
            }
            RawCallType::MethodCall { receiver } => format!(
                "method::{}::{}",
                receiver.as_deref().unwrap_or("?"),
                call.function_name
            ),
        };
        // `insert` returns false for a key that is already present.
        if self.seen.insert(key) {
            self.calls.push(call);
        }
    }
}
impl<'ast> Visit<'ast> for CallVisitor {
    /// Records calls whose callee is a path expression, then keeps walking
    /// into the call so nested calls are visited too.
    ///
    /// A single-segment path is recorded as a free function; a longer path is
    /// recorded as a static method whose qualifier is the second-to-last
    /// segment. Callees that are not plain paths are not recorded.
    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
        if let syn::Expr::Path(callee) = &*node.func {
            // Read the path back to front: called name first, then its qualifier.
            let mut idents = callee
                .path
                .segments
                .iter()
                .rev()
                .map(|segment| segment.ident.to_string());
            match (idents.next(), idents.next()) {
                (Some(function_name), None) => self.record(RawCall {
                    function_name,
                    call_type: RawCallType::FreeFunction,
                }),
                (Some(function_name), Some(type_name)) => self.record(RawCall {
                    function_name,
                    call_type: RawCallType::StaticMethod { type_name },
                }),
                (None, _) => {}
            }
        }
        syn::visit::visit_expr_call(self, node);
    }

    /// Records a method call together with its rendered receiver, then keeps
    /// walking into the receiver and arguments.
    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
        let call = RawCall {
            function_name: node.method.to_string(),
            call_type: RawCallType::MethodCall {
                receiver: receiver_to_string(&node.receiver),
            },
        };
        self.record(call);
        syn::visit::visit_expr_method_call(self, node);
    }
}
/// Renders a method-call receiver expression as text.
///
/// * path -> its identifiers joined with `::`
/// * field access -> `<base>.<member>`, or just `<member>` when the base
///   cannot be rendered
/// * method call -> `<base>.<method>()`, or `<method>()` when the base cannot
///   be rendered
/// * parentheses and references -> the rendering of the inner expression
/// * anything else -> `None`
fn receiver_to_string(expr: &syn::Expr) -> Option<String> {
    let rendered = match expr {
        // Wrappers are transparent: defer entirely to the wrapped expression.
        syn::Expr::Paren(wrapped) => return receiver_to_string(&wrapped.expr),
        syn::Expr::Reference(wrapped) => return receiver_to_string(&wrapped.expr),
        syn::Expr::Path(path_expr) => {
            let idents: Vec<String> = path_expr
                .path
                .segments
                .iter()
                .map(|segment| segment.ident.to_string())
                .collect();
            idents.join("::")
        }
        syn::Expr::Field(field_expr) => {
            let member = match &field_expr.member {
                syn::Member::Named(ident) => ident.to_string(),
                syn::Member::Unnamed(position) => position.index.to_string(),
            };
            match receiver_to_string(&field_expr.base) {
                Some(base) => format!("{}.{}", base, member),
                None => member,
            }
        }
        syn::Expr::MethodCall(inner_call) => {
            let prefix = receiver_to_string(&inner_call.receiver);
            let method = inner_call.method.to_string();
            match prefix {
                Some(base) => format!("{}.{}()", base, method),
                None => format!("{}()", method),
            }
        }
        _ => return None,
    };
    Some(rendered)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A visitor with nothing recorded yet.
    fn empty_visitor() -> CallVisitor {
        CallVisitor {
            calls: Vec::new(),
            seen: HashSet::new(),
        }
    }

    /// A free-function call to `name`.
    fn free_call(name: &str) -> RawCall {
        RawCall {
            function_name: name.to_string(),
            call_type: RawCallType::FreeFunction,
        }
    }

    /// A method call `receiver.name(...)`.
    fn method_call(name: &str, receiver: &str) -> RawCall {
        RawCall {
            function_name: name.to_string(),
            call_type: RawCallType::MethodCall {
                receiver: Some(receiver.to_string()),
            },
        }
    }

    /// Parses `source` as an expression and renders it as a receiver.
    fn rendered_receiver(source: &str) -> Option<String> {
        let expr: syn::Expr = syn::parse_str(source).unwrap();
        receiver_to_string(&expr)
    }

    #[test]
    fn test_path_fragment_simple() {
        let fragment =
            path_to_fragment("crate::instructions::admin::initialize::process", "process");
        assert_eq!(fragment.as_deref(), Some("instructions/admin/initialize"));
    }

    #[test]
    fn test_path_fragment_external() {
        let fragment = path_to_fragment("anchor_lang::prelude::Pubkey", "Pubkey");
        assert_eq!(fragment.as_deref(), Some("anchor_lang/prelude"));
    }

    #[test]
    fn test_path_fragment_empty() {
        assert_eq!(path_to_fragment("crate::foo", "foo"), None);
    }

    #[test]
    fn test_filter_filtered_function_name() {
        let mut visitor = empty_visitor();
        visitor.record(free_call("Ok"));
        visitor.record(free_call("real_fn"));
        assert_eq!(visitor.calls.len(), 1);
        assert_eq!(visitor.calls[0].function_name, "real_fn");
    }

    #[test]
    fn test_filter_filtered_method_name() {
        let mut visitor = empty_visitor();
        visitor.record(method_call("key", "account"));
        visitor.record(method_call("process", "ctx.accounts"));
        assert_eq!(visitor.calls.len(), 1);
        assert_eq!(visitor.calls[0].function_name, "process");
    }

    #[test]
    fn test_deduplication() {
        let mut visitor = empty_visitor();
        for _ in 0..3 {
            visitor.record(free_call("foo"));
        }
        assert_eq!(visitor.calls.len(), 1);
    }

    #[test]
    fn test_receiver_to_string_path() {
        assert_eq!(rendered_receiver("foo").as_deref(), Some("foo"));
    }

    #[test]
    fn test_receiver_to_string_nested_field() {
        assert_eq!(
            rendered_receiver("ctx.accounts").as_deref(),
            Some("ctx.accounts")
        );
    }

    #[test]
    fn test_receiver_to_string_self() {
        assert_eq!(rendered_receiver("self").as_deref(), Some("self"));
    }

    #[test]
    fn test_receiver_to_string_self_field() {
        assert_eq!(
            rendered_receiver("self.state").as_deref(),
            Some("self.state")
        );
    }
}