use crate::priority::call_graph::{CallGraph, CallType, FunctionId};
use std::path::PathBuf;
/// Minimal "pipe" combinator: hands a value to a closure and returns the closure's result.
///
/// `T` is the type the closure receives. It allows a sequence of transformation
/// steps to be written left to right as a method chain.
trait FunctionalPipe<T> {
    /// Consumes `self`, passes it to `f`, and returns whatever `f` produces.
    fn pipe<F: FnOnce(T) -> U, U>(self, f: F) -> U;
}
/// Blanket implementation: any value can be piped into a closure that accepts it.
impl<T> FunctionalPipe<T> for T {
    /// Applies `f` to `self` and returns the result.
    fn pipe<F: FnOnce(T) -> U, U>(self, f: F) -> U {
        f(self)
    }
}
/// Classification of a call site; the resolver uses it to decide how candidate
/// functions are filtered.
///
/// Derives `Eq` in addition to `PartialEq` (all fields are `String`/`Option<String>`),
/// matching `ResolutionOutcome` and allowing use where an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CallSiteType {
    /// Candidates must match the callee by exact or `::`-qualified name.
    Static,
    /// Method-style call; `receiver_type` is the receiver's type name when it is known.
    Instance { receiver_type: Option<String> },
    /// Call attributed to a trait; `trait_name` is checked against the resolver's
    /// list of ignored library traits, `receiver_type` narrows candidates when known.
    TraitMethod {
        trait_name: String,
        receiver_type: Option<String>,
    },
    /// Candidates are matched by name only (exact, qualified, or base name).
    Indirect,
}
/// A recorded call that has not yet been linked to a concrete `FunctionId`.
#[derive(Debug, Clone)]
pub struct UnresolvedCall {
/// Function the call was recorded for. Not consulted by the resolution code
/// visible in this file.
pub caller: FunctionId,
/// Callee name as recorded; may be `::`-qualified and may carry generic
/// arguments (the resolver strips those before lookup).
pub callee_name: String,
/// Kind of call. Not consulted by the resolution code visible in this file.
pub call_type: CallType,
/// How the call site was classified; selects the candidate filtering strategy.
pub call_site_type: CallSiteType,
/// When `true`, candidates located in the resolver's current file are preferred
/// (and, for some call-site types, required).
pub same_file_hint: bool, }
/// Result of trying to resolve an `UnresolvedCall`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolutionOutcome {
/// A single target function was selected.
Resolved(FunctionId),
/// The call was recognised as a standard-library style call and deliberately
/// not resolved.
IgnoredLibraryCall,
/// No candidate matched, or the candidates could not be disambiguated.
Unresolved,
}
use std::collections::HashMap;
/// Resolves `UnresolvedCall`s against the functions of a `CallGraph`, from the
/// point of view of one source file.
pub struct CallResolver<'a> {
// Stored by `new` but not read by any method visible in this file, hence the
// lint allowance.
#[allow(dead_code)]
call_graph: &'a CallGraph,
/// File used for same-file filtering and same-file preference.
current_file: &'a PathBuf,
/// Lookup table built in `new`: maps a generic-stripped function name, and the
/// last `::` segment of qualified names, to the functions registered under it.
function_index: HashMap<String, Vec<FunctionId>>,
}
impl<'a> CallResolver<'a> {
/// Builds a resolver over `call_graph` for calls that originate in `current_file`.
///
/// Every function is indexed under its generic-stripped name and, when the name
/// is `::`-qualified, also under its last path segment, so both `module::func`
/// and bare `func` lookups can find it. Buckets are sorted so that lookups
/// return candidates in a deterministic order.
pub fn new(call_graph: &'a CallGraph, current_file: &'a PathBuf) -> Self {
    let mut sorted_functions: Vec<FunctionId> =
        call_graph.get_all_functions().cloned().collect();
    sorted_functions.sort();

    let mut function_index: HashMap<String, Vec<FunctionId>> = HashMap::new();
    for func_id in sorted_functions {
        let full_key = Self::normalize_path_prefix(&func_id.name);
        function_index
            .entry(full_key)
            .or_default()
            .push(func_id.clone());

        // Only qualified names get a second entry: for an unqualified name the
        // last segment is the name itself.
        match func_id.name.split("::").last() {
            Some(segment) if segment != func_id.name => {
                function_index
                    .entry(segment.to_string())
                    .or_default()
                    .push(func_id.clone());
            }
            _ => {}
        }
    }

    function_index
        .values_mut()
        .for_each(|bucket| bucket.sort());

    Self {
        call_graph,
        current_file,
        function_index,
    }
}
/// Resolves `call` to its target function, if any.
///
/// Convenience wrapper around `resolve_call_outcome` that collapses both the
/// "ignored library call" and "unresolved" outcomes into `None`.
pub fn resolve_call(&self, call: &UnresolvedCall) -> Option<FunctionId> {
    if let ResolutionOutcome::Resolved(target) = self.resolve_call_outcome(call) {
        Some(target)
    } else {
        None
    }
}
pub fn resolve_call_outcome(&self, call: &UnresolvedCall) -> ResolutionOutcome {
if let CallSiteType::TraitMethod { trait_name, .. } = &call.call_site_type {
if matches!(
trait_name.as_str(),
"Iterator"
| "Option"
| "Clone"
| "ToString"
| "Display"
| "Default"
| "Hash"
| "IteratorOrOption"
) {
return ResolutionOutcome::IgnoredLibraryCall;
}
}
if let CallSiteType::Instance {
receiver_type: None,
} = &call.call_site_type
&& Self::is_common_library_method(&call.callee_name)
{
return ResolutionOutcome::IgnoredLibraryCall;
}
let normalized_name = Self::normalize_path_prefix(&call.callee_name);
let candidates = self.function_index.get(&normalized_name).or_else(|| {
if let Some(simple_name) = call.callee_name.split("::").last() {
self.function_index.get(simple_name)
} else {
None
}
});
let Some(candidates) = candidates else {
return ResolutionOutcome::Unresolved;
};
let matching_candidates: Vec<FunctionId> = match &call.call_site_type {
CallSiteType::Static => {
candidates
.iter()
.filter(|func| {
Self::is_exact_match(&func.name, &normalized_name)
|| Self::is_exact_match(&func.name, &call.callee_name)
|| Self::is_qualified_match(&func.name, &normalized_name)
|| Self::is_qualified_match(&func.name, &call.callee_name)
})
.cloned()
.collect()
}
CallSiteType::Instance {
receiver_type: Some(recv_type),
} => {
let expected_name = format!(
"{}::{}",
recv_type,
call.callee_name
.split("::")
.last()
.unwrap_or(&call.callee_name)
);
candidates
.iter()
.filter(|func| {
func.name == expected_name
|| func.name.starts_with(&format!("{}::", recv_type))
})
.cloned()
.collect()
}
CallSiteType::Instance {
receiver_type: None,
} => {
if Self::is_common_library_method(&call.callee_name) {
return ResolutionOutcome::IgnoredLibraryCall;
}
if call.same_file_hint {
candidates
.iter()
.filter(|func| {
func.file == *self.current_file
&& Self::is_function_match(
func,
&normalized_name,
&call.callee_name,
)
})
.cloned()
.collect()
} else {
candidates
.iter()
.filter(|func| {
Self::is_function_match(func, &normalized_name, &call.callee_name)
})
.cloned()
.collect()
}
}
CallSiteType::TraitMethod { receiver_type, .. } => {
if let Some(recv_type) = receiver_type {
let expected_name = format!(
"{}::{}",
recv_type,
call.callee_name
.split("::")
.last()
.unwrap_or(&call.callee_name)
);
candidates
.iter()
.filter(|func| func.name == expected_name)
.cloned()
.collect()
} else if call.same_file_hint {
candidates
.iter()
.filter(|func| func.file == *self.current_file)
.cloned()
.collect()
} else {
return ResolutionOutcome::Unresolved;
}
}
CallSiteType::Indirect => {
candidates
.iter()
.filter(|func| {
Self::is_function_match(func, &normalized_name, &call.callee_name)
})
.cloned()
.collect()
}
};
if matching_candidates.is_empty() {
return ResolutionOutcome::Unresolved;
}
Self::select_best_candidate(matching_candidates, self.current_file, call.same_file_hint)
.map(ResolutionOutcome::Resolved)
.unwrap_or(ResolutionOutcome::Unresolved)
}
/// Resolves `callee_name` against an explicit list of functions, without a
/// pre-built index.
///
/// Every function whose name matches (exactly, as a `::`-qualified suffix, or
/// by base name) is a candidate; the best one is chosen by
/// `select_best_candidate`. Returns `None` when nothing matches or the
/// candidates are ambiguous.
pub fn resolve_function_call(
    all_functions: &[FunctionId],
    callee_name: &str,
    current_file: &PathBuf,
    same_file_hint: bool,
) -> Option<FunctionId> {
    let normalized = Self::normalize_path_prefix(callee_name);

    let mut matches: Vec<FunctionId> = Vec::new();
    for func in all_functions {
        if Self::is_function_match(func, &normalized, callee_name) {
            matches.push(func.clone());
        }
    }

    if matches.is_empty() {
        None
    } else {
        Self::select_best_candidate(matches, current_file, same_file_hint)
    }
}
/// Normalizes a function name for index lookups and comparisons.
///
/// Despite the name, path prefixes such as `crate::`, `self::` or `super::`
/// are left untouched; the only normalization performed is the generic
/// stripping done by `strip_generic_params`.
pub fn normalize_path_prefix(name: &str) -> String {
Self::strip_generic_params(name)
}
pub fn strip_generic_params(name: &str) -> String {
let without_turbofish = if let Some(pos) = name.find("::<") {
if let Some(end) = Self::find_matching_bracket(&name[pos + 3..]) {
format!("{}{}", &name[..pos], &name[pos + 3 + end + 1..])
} else {
name.to_string()
}
} else {
name.to_string()
};
if let Some(pos) = without_turbofish.find('<') {
if let Some(end) = Self::find_matching_bracket(&without_turbofish[pos + 1..]) {
format!(
"{}{}",
&without_turbofish[..pos],
&without_turbofish[pos + 1 + end + 1..]
)
} else {
without_turbofish
}
} else {
without_turbofish
}
}
/// Finds the `>` that closes a bracket opened just before `s`.
///
/// `s` is the text after an already-consumed `<`, so the scan starts at
/// nesting depth 1. Returns the byte offset of the closing `>` within `s`, or
/// `None` if the bracket is never closed.
///
/// The result is a byte offset (not a character count) because callers use it
/// to slice `s`. The two only differ when `s` contains non-ASCII characters.
fn find_matching_bracket(s: &str) -> Option<usize> {
    let mut depth: usize = 1;
    for (offset, ch) in s.char_indices() {
        match ch {
            '<' => depth += 1,
            '>' => {
                // Cannot underflow: we return as soon as depth reaches zero.
                depth -= 1;
                if depth == 0 {
                    return Some(offset);
                }
            }
            _ => {}
        }
    }
    None
}
/// Picks one function out of a set of name-matched candidates.
///
/// * A single candidate is returned as is.
/// * Without a same-file hint, several candidates that all share one name and
///   each live in a different file are considered ambiguous: `None`.
/// * Otherwise the set is narrowed by same-file preference, then lowest
///   qualification, then non-generic preference, and the first survivor wins.
fn select_best_candidate(
    candidates: Vec<FunctionId>,
    current_file: &PathBuf,
    same_file_hint: bool,
) -> Option<FunctionId> {
    let total = candidates.len();
    if total == 1 {
        return candidates.into_iter().next();
    }

    // Short-circuiting keeps `candidates[0]` from being touched on an empty set.
    let ambiguous_across_files = !same_file_hint
        && total > 1
        && candidates.iter().all(|c| c.name == candidates[0].name)
        && candidates
            .iter()
            .map(|c| &c.file)
            .collect::<std::collections::HashSet<_>>()
            .len()
            == total;
    if ambiguous_across_files {
        return None;
    }

    candidates
        .pipe(|funcs| Self::apply_same_file_preference(funcs, current_file, same_file_hint))
        .pipe(Self::apply_qualification_preference)
        .pipe(Self::apply_generic_preference)
        .into_iter()
        .next()
}
/// Narrows `candidates` to those defined in `current_file`, when asked to.
///
/// The set is returned unchanged if `same_file_hint` is `false` or if no
/// candidate lives in `current_file`. Relative order is preserved.
fn apply_same_file_preference(
    candidates: Vec<FunctionId>,
    current_file: &PathBuf,
    same_file_hint: bool,
) -> Vec<FunctionId> {
    let any_in_current_file =
        same_file_hint && candidates.iter().any(|func| &func.file == current_file);

    if any_in_current_file {
        candidates
            .into_iter()
            .filter(|func| &func.file == current_file)
            .collect()
    } else {
        candidates
    }
}
/// Keeps only the candidates with the lowest qualification score.
///
/// Sets of zero or one candidate are returned unchanged. Relative order of the
/// survivors is preserved.
fn apply_qualification_preference(candidates: Vec<FunctionId>) -> Vec<FunctionId> {
    if candidates.len() <= 1 {
        return candidates;
    }

    let mut lowest = usize::MAX;
    for func in &candidates {
        lowest = lowest.min(Self::calculate_qualification_score(&func.name));
    }

    let mut survivors = candidates;
    survivors.retain(|func| Self::calculate_qualification_score(&func.name) == lowest);
    survivors
}
/// Prefers non-generic candidates over generic ones.
///
/// If at least one candidate is non-generic, only the non-generic ones are
/// kept. Otherwise, or when there are fewer than two candidates, the set is
/// returned unchanged. Relative order is preserved.
fn apply_generic_preference(candidates: Vec<FunctionId>) -> Vec<FunctionId> {
    if candidates.len() <= 1 {
        return candidates;
    }

    let has_non_generic = candidates
        .iter()
        .any(|func| !Self::is_generic_function(&func.name));
    if !has_non_generic {
        return candidates;
    }

    candidates
        .into_iter()
        .filter(|func| !Self::is_generic_function(&func.name))
        .collect()
}
/// Scores how heavily qualified `name` is; lower scores are preferred.
///
/// The score is the number of `::` separators, plus a penalty of 1000 when the
/// name contains both `<` and `>`.
fn calculate_qualification_score(name: &str) -> usize {
    const ANGLE_BRACKET_PENALTY: usize = 1000;

    let separators = name.matches("::").count();
    let penalty = if name.contains('<') && name.contains('>') {
        ANGLE_BRACKET_PENALTY
    } else {
        0
    };
    separators + penalty
}
/// Returns `true` when `name` contains both a `<` and a `>` character.
fn is_generic_function(name: &str) -> bool {
    ['<', '>'].iter().all(|&bracket| name.contains(bracket))
}
/// Returns `true` when `func` could be the function referred to by a call.
///
/// Both the normalized and the original spelling of the callee are tried. Each
/// may match the function name exactly, as a `::`-qualified suffix, or as the
/// part after the last `::`.
pub fn is_function_match(
    func: &FunctionId,
    normalized_name: &str,
    original_name: &str,
) -> bool {
    let candidate: &str = &func.name;
    [normalized_name, original_name].iter().any(|&wanted| {
        Self::is_exact_match(candidate, wanted)
            || Self::is_qualified_match(candidate, wanted)
            || Self::is_base_name_match(candidate, wanted)
    })
}
/// Returns `true` when `func_name` and `search_name` are identical strings.
fn is_exact_match(func_name: &str, search_name: &str) -> bool {
func_name == search_name
}
/// Returns `true` when `func_name` ends with `::` followed by `search_name`,
/// i.e. `search_name` is a path suffix of `func_name` preceded by at least one
/// more path segment.
///
/// Implemented with `strip_suffix` so that no temporary `String` is allocated;
/// this runs once per candidate inside the resolver's filters.
fn is_qualified_match(func_name: &str, search_name: &str) -> bool {
    func_name
        .strip_suffix(search_name)
        .is_some_and(|prefix| prefix.ends_with("::"))
}
/// Returns `true` when the part of `func_name` after its last `::` equals
/// `search_name`. Names without any `::` never match.
fn is_base_name_match(func_name: &str, search_name: &str) -> bool {
    matches!(
        func_name.rsplit_once("::"),
        Some((_, base_name)) if base_name == search_name
    )
}
/// Extracts the owning type from a caller name of the form `Type::method`.
///
/// Returns `Some(Type)` only when the part before the last `::` is a single
/// path segment whose first character is uppercase. Unqualified names, names
/// with a longer path (`module::Type::method`), and lowercase prefixes all
/// yield `None`.
pub fn extract_impl_type_from_caller(caller_name: &str) -> Option<String> {
    let (owner, _method) = caller_name.rsplit_once("::")?;
    let initial = owner.chars().next()?;
    (!owner.contains("::") && initial.is_uppercase()).then(|| owner.to_string())
}
/// Classifies a call by the spelling of its name: names starting with `self.`
/// are `CallType::Delegate`, everything else is `CallType::Direct`.
pub fn classify_call_type(name: &str) -> CallType {
    match name.strip_prefix("self.") {
        Some(_) => CallType::Delegate,
        None => CallType::Direct,
    }
}
/// Substitutes the `Self` keyword in `name` with the current impl type.
///
/// Only stand-alone occurrences of `Self` are replaced: `Self::new` becomes
/// `Type::new` and `<Self as Trait>::m` becomes `<Type as Trait>::m`.
/// Identifiers that merely contain the letters (`SelfCheck`, `my_Self`) are
/// left untouched. When `current_impl_type` is `None`, `name` is returned
/// unchanged.
pub fn resolve_self_type(name: &str, current_impl_type: &Option<String>) -> String {
    let Some(impl_type) = current_impl_type else {
        return name.to_string();
    };

    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    let mut resolved = String::with_capacity(name.len());
    let mut copied_up_to = 0;
    for (start, keyword) in name.match_indices("Self") {
        let end = start + keyword.len();
        // An occurrence glued to identifier characters is part of a longer
        // identifier, not the `Self` keyword.
        let glued_before = name[..start]
            .chars()
            .next_back()
            .is_some_and(is_ident_char);
        let glued_after = name[end..].chars().next().is_some_and(is_ident_char);
        if glued_before || glued_after {
            continue;
        }
        resolved.push_str(&name[copied_up_to..start]);
        resolved.push_str(impl_type);
        copied_up_to = end;
    }
    resolved.push_str(&name[copied_up_to..]);
    resolved
}
/// Heuristic: does a call with this name target the current file?
///
/// `true` for bare names (no `::` and no leading `self.`), and for names
/// containing `Self::` while an impl type is known. `false` otherwise.
pub fn is_same_file_call(name: &str, current_impl_type: &Option<String>) -> bool {
    let is_bare_name = !(name.contains("::") || name.starts_with("self."));
    let is_self_path = current_impl_type.is_some() && name.contains("Self::");
    is_bare_name || is_self_path
}
/// Returns `true` when `receiver` is a path expression consisting of the single
/// identifier `self`.
pub fn is_self_receiver(receiver: &syn::Expr) -> bool {
    if let syn::Expr::Path(expr_path) = receiver {
        expr_path.path.is_ident("self")
    } else {
        false
    }
}
/// Builds the name under which a method call should be looked up.
///
/// * Receiver type `Self` with a known impl type: `ImplType::method`.
/// * Any other receiver type (including `Self` with no impl type known):
///   `ReceiverType::method`.
/// * No receiver type: just `method`.
pub fn construct_method_name(
    receiver_type: Option<String>,
    method_name: &str,
    current_impl_type: &Option<String>,
) -> String {
    match (receiver_type.as_deref(), current_impl_type) {
        (Some("Self"), Some(impl_type)) => format!("{}::{}", impl_type, method_name),
        (Some(recv_type), _) => format!("{}::{}", recv_type, method_name),
        (None, _) => method_name.to_string(),
    }
}
/// Returns `true` when `method_name` is on the built-in list of method names
/// treated as standard-library trait or `Option`/`Result` methods.
///
/// The check is purely by name and is case-sensitive.
pub fn is_std_trait_method(method_name: &str) -> bool {
    const STD_TRAIT_METHODS: &[&str] = &[
        // Iterator-style adaptors and consumers.
        "any", "all", "map", "filter", "fold", "reduce",
        "collect", "find", "position", "enumerate", "zip",
        "chain", "flat_map", "flatten", "skip", "take",
        "cloned", "copied", "cycle", "rev", "peekable",
        "for_each", "nth", "last", "step_by", "scan",
        "fuse", "inspect", "partition", "try_fold", "try_for_each",
        // Option/Result-style accessors and combinators.
        "unwrap", "expect", "unwrap_or", "unwrap_or_else",
        "and_then", "or_else", "is_some", "is_none",
        "is_ok", "is_err", "as_ref", "as_mut", "ok", "err",
        "transpose", "unwrap_or_default",
        // Conversion, comparison, hashing and formatting names.
        "clone", "to_string", "to_owned", "into", "from",
        "default", "eq", "ne", "cmp", "partial_cmp",
        "hash", "fmt", "display",
    ];

    STD_TRAIT_METHODS
        .iter()
        .any(|&known| known == method_name)
}
/// Returns `true` when `method_name` is either a standard trait method (see
/// `is_std_trait_method`) or on the built-in list of common collection and
/// string method names.
///
/// The check is purely by name and is case-sensitive.
pub fn is_common_library_method(method_name: &str) -> bool {
    const COLLECTION_METHODS: &[&str] = &[
        "iter", "iter_mut", "into_iter", "is_empty", "len", "capacity",
        "contains", "get", "first", "last", "as_slice", "as_str",
        "push", "pop", "insert", "remove", "clear", "extend",
        "retain", "sort", "sort_by", "dedup",
        "as_deref", "as_deref_mut",
    ];

    Self::is_std_trait_method(method_name)
        || COLLECTION_METHODS
            .iter()
            .any(|&known| known == method_name)
}
/// Maps a method name to the label of the trait/type family it is assumed to
/// belong to.
///
/// Returns one of `"Iterator"`, `"IteratorOrOption"` (for `map`), `"Option"`,
/// `"Clone"`, `"ToString"`, `"Display"`, `"Default"`, `"Hash"`, or `"Unknown"`
/// for any name not on the lists below.
pub fn infer_trait_name(method_name: &str) -> String {
    let label: &str = match method_name {
        "any" | "all" | "filter" | "fold" | "reduce" | "collect" | "find" | "position"
        | "enumerate" | "zip" | "chain" | "flat_map" | "flatten" | "skip" | "take"
        | "cloned" | "copied" | "cycle" | "rev" | "peekable" | "for_each" | "nth" | "last"
        | "step_by" | "scan" | "fuse" | "inspect" | "partition" | "try_fold"
        | "try_for_each" => "Iterator",
        // `map` exists on both iterators and `Option`, so it gets its own label.
        "map" => "IteratorOrOption",
        "unwrap" | "expect" | "unwrap_or" | "unwrap_or_else" | "and_then" | "or_else"
        | "is_some" | "is_none" | "is_ok" | "is_err" | "as_ref" | "as_mut" | "ok" | "err"
        | "transpose" | "unwrap_or_default" => "Option",
        "clone" => "Clone",
        "to_string" | "display" => "ToString",
        "fmt" => "Display",
        "default" => "Default",
        "hash" => "Hash",
        _ => "Unknown",
    };
    label.to_string()
}
}
#[cfg(test)]
mod tests {
use super::*;
// End-to-end check of `resolve_function_call`: an unqualified name resolves to
// the function of that name in the current file, with and without the
// same-file hint.
#[test]
fn test_functional_refactoring_integration() {
    let current_file = PathBuf::from("test.rs");
    let functions = vec![
        FunctionId::new(current_file.clone(), "simple_func".to_string(), 10),
        FunctionId::new(current_file.clone(), "module::complex_func".to_string(), 20),
        FunctionId::new(PathBuf::from("other.rs"), "other_func".to_string(), 30),
    ];

    // With the same-file hint.
    let result =
        CallResolver::resolve_function_call(&functions, "simple_func", &current_file, true);
    assert!(result.is_some());
    let resolved = result.unwrap();
    assert_eq!(resolved.name, "simple_func");
    assert_eq!(resolved.file, current_file);

    // Without the hint, the single matching candidate is still returned.
    let result_no_hint =
        CallResolver::resolve_function_call(&functions, "simple_func", &current_file, false);
    assert!(result_no_hint.is_some());
    assert_eq!(result_no_hint.unwrap().name, "simple_func");
}
// Names without generic arguments must come back unchanged; in particular
// `crate::`, `self::` and `super::` prefixes are not stripped.
#[test]
fn test_normalize_path_prefix() {
    for name in ["crate::module::func", "self::func", "super::func", "func"] {
        assert_eq!(CallResolver::normalize_path_prefix(name), name);
    }
}
// Table-driven check of matching on the segment after the last `::`.
#[test]
fn test_is_base_name_match() {
    let cases = [
        ("MyStruct::method", "method", true),
        ("module::MyStruct::method", "method", true),
        ("module::function", "function", true),
        ("MyStruct::method", "other", false),
    ];
    for (func_name, search_name, expected) in cases {
        assert_eq!(
            CallResolver::is_base_name_match(func_name, search_name),
            expected,
            "is_base_name_match({func_name:?}, {search_name:?})"
        );
    }
}
// Only names starting with `self.` are classified as delegate calls.
#[test]
fn test_classify_call_type() {
    let cases = [
        ("module::func", CallType::Direct),
        ("Type::method", CallType::Direct),
        ("self.method", CallType::Delegate),
        ("func", CallType::Direct),
    ];
    for (name, expected) in cases {
        assert_eq!(CallResolver::classify_call_type(name), expected);
    }
}
// `Self` is substituted when an impl type is known and left alone otherwise.
#[test]
fn test_resolve_self_type() {
    let with_impl = Some("MyStruct".to_string());
    for (input, expected) in [("Self::new", "MyStruct::new"), ("Self", "MyStruct")] {
        assert_eq!(CallResolver::resolve_self_type(input, &with_impl), expected);
    }

    let without_impl: Option<String> = None;
    assert_eq!(
        CallResolver::resolve_self_type("Self::new", &without_impl),
        "Self::new"
    );
}
// Bare names and `Self::` paths (with a known impl type) count as same-file
// calls; module-qualified names and `self.` calls do not.
#[test]
fn test_is_same_file_call() {
    let known_impl = Some("MyStruct".to_string());
    let no_impl: Option<String> = None;

    let cases = [
        ("simple_func", &no_impl, true),
        ("Self::method", &known_impl, true),
        ("module::func", &no_impl, false),
        ("self.method", &no_impl, false),
    ];
    for (name, impl_type, expected) in cases {
        assert_eq!(
            CallResolver::is_same_file_call(name, impl_type),
            expected,
            "is_same_file_call({name:?})"
        );
    }
}
// Only a single uppercase-initial segment in front of the method name is
// treated as the impl type.
#[test]
fn test_extract_impl_type_from_caller() {
    let cases = [
        ("MyStruct::method", Some("MyStruct")),
        ("module::MyStruct::method", None),
        ("function", None),
    ];
    for (caller_name, expected) in cases {
        assert_eq!(
            CallResolver::extract_impl_type_from_caller(caller_name).as_deref(),
            expected
        );
    }
}
// When a bare function, a module-qualified function and a method all share the
// base name `calculate` in the same file, the least qualified one must win.
#[test]
fn test_complex_matching_scenarios() {
    let current_file = PathBuf::from("test.rs");
    let simple_func = FunctionId::new(current_file.clone(), "calculate".to_string(), 10);
    let qualified_func =
        FunctionId::new(current_file.clone(), "utils::calculate".to_string(), 20);
    let method_func = FunctionId::new(
        current_file.clone(),
        "Calculator::calculate".to_string(),
        30,
    );
    // The unqualified function is deliberately listed last.
    let functions = vec![
        qualified_func.clone(),
        method_func.clone(),
        simple_func.clone(),
    ];
    let result =
        CallResolver::resolve_function_call(&functions, "calculate", &current_file, true);
    assert!(result.is_some());
    assert_eq!(result.unwrap().name, "calculate");
}
// Exercises two of the narrowing stages used by candidate selection.
#[test]
fn test_functional_pipeline_composition() {
    let current_file = PathBuf::from("test.rs");

    // Qualification preference keeps only the least-qualified name.
    let by_qualification = CallResolver::apply_qualification_preference(vec![
        FunctionId::new(current_file.clone(), "func".to_string(), 10),
        FunctionId::new(current_file.clone(), "mod::func".to_string(), 20),
        FunctionId::new(current_file.clone(), "deep::mod::func".to_string(), 30),
    ]);
    let [least_qualified] = by_qualification.as_slice() else {
        panic!(
            "expected exactly one candidate, got {}",
            by_qualification.len()
        );
    };
    assert_eq!(least_qualified.name, "func");

    // Generic preference drops generic names when a non-generic one exists.
    let by_genericity = CallResolver::apply_generic_preference(vec![
        FunctionId::new(current_file.clone(), "regular_func".to_string(), 10),
        FunctionId::new(current_file.clone(), "generic_func<T>".to_string(), 20),
    ]);
    let [non_generic] = by_genericity.as_slice() else {
        panic!("expected exactly one candidate, got {}", by_genericity.len());
    };
    assert_eq!(non_generic.name, "regular_func");
}
// Table-driven check that generic argument lists (plain, turbofish, nested)
// are removed and that names without generics are untouched.
#[test]
fn test_strip_generic_params() {
    let cases = [
        ("foo<T>", "foo"),
        ("bar::<Type>", "bar"),
        ("func<Vec<String>>", "func"),
        ("map<K, V>", "map"),
        ("module::function<T>", "module::function"),
        ("Type::method<T, U>", "Type::method"),
        ("simple_function", "simple_function"),
        ("complex<HashMap<String, Vec<i32>>>", "complex"),
    ];
    for (input, expected) in cases {
        assert_eq!(
            CallResolver::strip_generic_params(input),
            expected,
            "strip_generic_params({input:?})"
        );
    }
}
// The input is the text after an opening `<`; the expected value is the
// position of the `>` that closes it, or `None` when it is never closed.
#[test]
fn test_find_matching_bracket() {
    let cases = [
        ("T>", Some(1)),
        ("Vec<String>>", Some(11)),
        ("HashMap<String, Vec<i32>>>", Some(25)),
        ("T", None),
        ("T<U", None),
    ];
    for (input, expected) in cases {
        assert_eq!(
            CallResolver::find_matching_bracket(input),
            expected,
            "find_matching_bracket({input:?})"
        );
    }
}
}