use super::call_graph::{CallGraph, FunctionId};
/// Whether a caller belongs to production code or to test code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CallerType {
    /// The caller was not recognised as test code.
    Production,
    /// The caller is a test function, a test helper, or otherwise matched a test pattern.
    Test,
}
/// Callers of some function, partitioned into production and test code.
///
/// The `*_count` fields are kept in step with the vectors by `classify_callers`;
/// they are separate public fields, so code that builds this struct by hand is
/// responsible for keeping them consistent.
#[derive(Clone, Debug, Default)]
pub struct ClassifiedCallers {
    /// Callers classified as [`CallerType::Production`].
    pub production: Vec<String>,
    /// Callers classified as [`CallerType::Test`].
    pub test: Vec<String>,
    /// Number of production callers.
    pub production_count: usize,
    /// Number of test callers.
    pub test_count: usize,
}

impl ClassifiedCallers {
    /// Creates an empty classification: no callers and both counts at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `production_count + test_count`.
    pub fn total_count(&self) -> usize {
        let Self {
            production_count,
            test_count,
            ..
        } = self;
        production_count + test_count
    }
}
/// Classifies `caller` as production or test code.
///
/// When a call graph is supplied, the caller is first resolved to a
/// [`FunctionId`] via `parse_caller_to_func_id`. That id is checked against the
/// graph's test-function and test-helper markers, then by bare-name lookup.
/// Callers that cannot be parsed are looked up by name directly. Anything the
/// graph does not confirm as test code falls through to
/// [`classify_by_heuristics`], as does every caller when `call_graph` is `None`.
pub fn classify_caller(caller: &str, call_graph: Option<&CallGraph>) -> CallerType {
    let confirmed_by_graph = match call_graph {
        None => false,
        Some(graph) => match parse_caller_to_func_id(caller) {
            // Parsed ids carry line 0 and a path guessed from the caller text, so
            // the exact-id checks may not hit; the name search covers that case.
            Some(id) => {
                graph.is_test_function(&id)
                    || graph.is_test_helper(&id)
                    || is_test_function_by_name(&id.name, graph)
            }
            // Unparseable callers are treated as bare function names.
            None => is_test_function_by_name(caller, graph),
        },
    };

    if confirmed_by_graph {
        CallerType::Test
    } else {
        classify_by_heuristics(caller)
    }
}
/// Returns `true` if `call_graph` contains a test function called `name`.
///
/// A graph entry matches when its name equals `name` exactly, or when `name` is
/// its last `::`-separated segment. For example, `foo` matches an entry named
/// `test::foo`.
///
/// This scans the graph's functions in order and stops at the first matching
/// test function.
fn is_test_function_by_name(name: &str, call_graph: &CallGraph) -> bool {
    call_graph.get_all_functions().any(|candidate| {
        // `prefix` empty => exact match; `prefix` ending in "::" => qualified match.
        let name_matches = matches!(
            candidate.name.strip_suffix(name),
            Some(prefix) if prefix.is_empty() || prefix.ends_with("::")
        );
        name_matches && call_graph.is_test_function(candidate)
    })
}
/// Parses a caller string into a [`FunctionId`] whose line number is 0.
///
/// Accepted forms:
/// - `prefix::func`. If `prefix` contains `/` or ends in a recognised source
///   extension (`.rs`, `.py`, `.js`, `.ts`), it is used verbatim as the file
///   (e.g. `tests/test_x.py::test_y` or `test_x.py::test_y`). Otherwise it is
///   treated as a Rust module path and mapped to a file: `a::b::f` becomes
///   `a/b.rs` with function `f`.
/// - `file.ext:func` (single colon), where `file.ext` has a recognised source
///   extension.
///
/// Returns `None` for anything else, e.g. a bare function name or
/// `file.txt:func`.
fn parse_caller_to_func_id(caller: &str) -> Option<FunctionId> {
    // Extensions that mark a prefix as a file path. Shared by both forms below,
    // so `test_x.py::f` is no longer mangled into `test_x.py.rs`.
    const SOURCE_EXTENSIONS: [&str; 4] = [".rs", ".py", ".js", ".ts"];
    let has_source_extension =
        |path: &str| SOURCE_EXTENSIONS.iter().any(|ext| path.ends_with(ext));

    // Split on the last "::"; whenever "::" is present this branch returns.
    if let Some((path_or_module, func_name)) = caller.rsplit_once("::") {
        let file = if path_or_module.contains('/') || has_source_extension(path_or_module) {
            std::path::PathBuf::from(path_or_module)
        } else {
            // Rust module path: `a::b` -> `a/b.rs`.
            std::path::PathBuf::from(format!("{}.rs", path_or_module.replace("::", "/")))
        };
        return Some(FunctionId::new(file, func_name.to_string(), 0));
    }

    // Only reached when there is no "::", so this is a genuine single-colon split.
    if let Some((file_path, func_name)) = caller.rsplit_once(':') {
        if has_source_extension(file_path) {
            return Some(FunctionId::new(
                std::path::PathBuf::from(file_path),
                func_name.to_string(),
                0,
            ));
        }
    }

    None
}
/// Classifies a caller purely from its text, without consulting a call graph.
///
/// Matching is done on a lowercased copy of `caller`. The caller is
/// [`CallerType::Test`] when any of the following holds, and
/// [`CallerType::Production`] otherwise:
/// 1. it contains a test-location marker (`/tests/`, `/test/`, `::tests::`,
///    `::test::`, `:test:`, `:tests:`);
/// 2. its file part looks like a test file (`is_test_file_path`);
/// 3. its function name (the segment after the last `:` or `/`):
///    - starts with a test-style prefix (`test_`, `should_`, `mock_`, ...), or
///    - ends with a test-style suffix (`_test`, `_spec`, ...), or
///    - contains an underscore-delimited test word (`_test_`, `_assert_`, ...).
///
/// Every name pattern includes an underscore next to the keyword, so names such
/// as `latest_version` or `attest_function` stay production.
pub fn classify_by_heuristics(caller: &str) -> CallerType {
    const LOCATION_MARKERS: [&str; 6] = [
        "/tests/", "/test/", "::tests::", "::test::", ":test:", ":tests:",
    ];
    const NAME_PREFIXES: [&str; 12] = [
        "test_", "tests_", "should_", "it_", "spec_", "verify_", "when_", "given_", "mock_",
        "stub_", "fake_", "fixture_",
    ];
    const NAME_SUFFIXES: [&str; 6] = ["_test", "_tests", "_spec", "_mock", "_stub", "_fixture"];
    const NAME_INFIXES: [&str; 6] = [
        "_test_", "_spec_", "_assert_", "_expect_", "_setup_", "_teardown_",
    ];

    let lowered = caller.to_lowercase();

    let in_test_location = LOCATION_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
        || is_test_file_path(&lowered);
    if in_test_location {
        return CallerType::Test;
    }

    let name = extract_function_name(&lowered);
    let test_like_name = NAME_PREFIXES.iter().any(|p| name.starts_with(*p))
        || NAME_SUFFIXES.iter().any(|s| name.ends_with(*s))
        || NAME_INFIXES.iter().any(|w| name.contains(*w));

    if test_like_name {
        CallerType::Test
    } else {
        CallerType::Production
    }
}
/// Returns `true` if the file part of `caller` (the text before its first `:`)
/// looks like a test file.
///
/// Matching is case-sensitive; `classify_by_heuristics` passes lowercased input.
/// A file counts as a test file when any of these holds:
/// - it lives under a `tests/` directory (`tests/...` or `.../tests/...`);
/// - it is a `.rs` or `.py` file whose stem starts with `test_` or ends with
///   `_test` or `_tests`;
/// - it is a `.js` or `.ts` file whose stem ends with `.test` or `.spec`.
///
/// The extensions are the ones `parse_caller_to_func_id` accepts in
/// `file:func` form. Previously only `.rs` names were checked, so e.g.
/// `test_parser.py:helper` was reported as production code.
fn is_test_file_path(caller: &str) -> bool {
    let file_part = caller.split(':').next().unwrap_or("");

    if file_part.contains("/tests/") || file_part.starts_with("tests/") {
        return true;
    }

    let file_name = file_part.rsplit('/').next().unwrap_or(file_part);

    // Rust / Python naming: test_foo.rs, foo_test.py, foo_tests.rs, ...
    if let Some(stem) = file_name
        .strip_suffix(".rs")
        .or_else(|| file_name.strip_suffix(".py"))
    {
        return stem.starts_with("test_") || stem.ends_with("_test") || stem.ends_with("_tests");
    }

    // JavaScript / TypeScript naming: foo.test.js, foo.spec.ts, ...
    if let Some(stem) = file_name
        .strip_suffix(".js")
        .or_else(|| file_name.strip_suffix(".ts"))
    {
        return stem.ends_with(".test") || stem.ends_with(".spec");
    }

    false
}
/// Returns the trailing segment of `caller` after its last `:` or `/`.
///
/// Because `::` is just two colons, this also strips module paths:
/// `a::b::f`, `dir/file.rs:f` and `dir/file::f` all yield `f`. A string with
/// neither separator is returned unchanged; a trailing separator yields `""`.
fn extract_function_name(caller: &str) -> &str {
    match caller.rfind(|c: char| c == ':' || c == '/') {
        // Both separators are single-byte ASCII, so `idx + 1` is a char boundary.
        Some(idx) => &caller[idx + 1..],
        None => caller,
    }
}
/// Classifies every caller with [`classify_caller`] and buckets the results.
///
/// Each caller string is cloned into either `production` or `test`, and the
/// matching count is incremented. Input order is preserved within each bucket.
pub fn classify_callers<'a>(
    callers: impl Iterator<Item = &'a String>,
    call_graph: Option<&CallGraph>,
) -> ClassifiedCallers {
    callers.fold(ClassifiedCallers::new(), |mut acc, caller| {
        // Borrow the target vector and its counter together (disjoint fields).
        let (bucket, count) = match classify_caller(caller, call_graph) {
            CallerType::Production => (&mut acc.production, &mut acc.production_count),
            CallerType::Test => (&mut acc.test, &mut acc.test_count),
        };
        bucket.push(caller.clone());
        *count += 1;
        acc
    })
}
// Unit tests for caller classification. The early tests exercise the text-only
// heuristics (`classify_by_heuristics`, `extract_function_name`,
// `is_test_file_path`, `parse_caller_to_func_id`); the later ones build a small
// `CallGraph` to cover the graph-backed path of `classify_caller`.
#[cfg(test)]
mod tests {
use super::*;
// Name-only heuristics: test-style prefixes/suffixes classify as Test, plain
// names as Production.
#[test]
fn test_classify_by_name_patterns() {
assert_eq!(classify_by_heuristics("test_parse_array"), CallerType::Test);
assert_eq!(
classify_by_heuristics("should_reflow_long_lines"),
CallerType::Test
);
assert_eq!(
classify_by_heuristics("it_formats_correctly"),
CallerType::Test
);
assert_eq!(classify_by_heuristics("spec_overflow"), CallerType::Test);
assert_eq!(classify_by_heuristics("verify_output"), CallerType::Test);
assert_eq!(
classify_by_heuristics("mock_database_connection"),
CallerType::Test
);
assert_eq!(classify_by_heuristics("stub_api_client"), CallerType::Test);
assert_eq!(
classify_by_heuristics("fixture_user_data"),
CallerType::Test
);
assert_eq!(classify_by_heuristics("parse_array_test"), CallerType::Test);
assert_eq!(
classify_by_heuristics("overflow_handler_spec"),
CallerType::Test
);
// Names with no test-style prefix, suffix or infix stay Production.
assert_eq!(
classify_by_heuristics("process_file"),
CallerType::Production
);
assert_eq!(classify_by_heuristics("main"), CallerType::Production);
assert_eq!(
classify_by_heuristics("parse_tokens"),
CallerType::Production
);
assert_eq!(
classify_by_heuristics("handle_request"),
CallerType::Production
);
}
// Path/module markers (`/tests/`, `::tests::`, `::test::`) classify as Test
// regardless of the function name.
#[test]
fn test_classify_by_path_patterns() {
assert_eq!(
classify_by_heuristics("src/tests/helpers::create_mock"),
CallerType::Test
);
assert_eq!(
classify_by_heuristics("module::tests::test_function"),
CallerType::Test
);
assert_eq!(
classify_by_heuristics("crate::test::helpers::setup"),
CallerType::Test
);
}
// Names that merely contain the letters "test" (no underscore boundary) must
// not be mistaken for tests.
#[test]
fn test_classify_production_functions() {
assert_eq!(
classify_by_heuristics("attest_function"),
CallerType::Production
);
assert_eq!(
classify_by_heuristics("contest_winner"),
CallerType::Production
);
assert_eq!(
classify_by_heuristics("latest_version"),
CallerType::Production
);
assert_eq!(
classify_by_heuristics("testing_mode_check"),
CallerType::Production
);
}
// Without a call graph, `classify_callers` buckets by heuristics and keeps the
// counts in step with the vectors.
#[test]
fn test_classify_callers_separates_correctly() {
let callers = [
"test_parse_array".to_string(),
"process_file".to_string(),
"should_format".to_string(),
"main".to_string(),
"verify_output".to_string(),
];
let result = classify_callers(callers.iter(), None);
assert_eq!(result.production_count, 2);
assert_eq!(result.test_count, 3);
assert!(result.production.contains(&"process_file".to_string()));
assert!(result.production.contains(&"main".to_string()));
assert!(result.test.contains(&"test_parse_array".to_string()));
assert!(result.test.contains(&"should_format".to_string()));
assert!(result.test.contains(&"verify_output".to_string()));
}
// `extract_function_name` strips `::` module paths and `/` directories.
#[test]
fn test_extract_function_name() {
assert_eq!(extract_function_name("module::function"), "function");
assert_eq!(extract_function_name("path/to/file::function"), "function");
assert_eq!(extract_function_name("function"), "function");
assert_eq!(extract_function_name("a::b::c::function"), "function");
}
// BDD-style `when_` / `given_` prefixes classify as Test.
#[test]
fn test_classify_bdd_patterns() {
assert_eq!(
classify_by_heuristics("when_user_clicks_button"),
CallerType::Test
);
assert_eq!(
classify_by_heuristics("given_valid_input"),
CallerType::Test
);
}
// Underscore-delimited infixes such as `_test_` classify as Test.
#[test]
fn test_classify_word_boundary_patterns() {
assert_eq!(classify_by_heuristics("user_test_helper"), CallerType::Test);
assert_eq!(classify_by_heuristics("setup_test_data"), CallerType::Test);
// The `_assert_` infix needs a leading underscore and there is no `assert_`
// prefix pattern, so a name that starts with `assert_` stays Production.
assert_eq!(
classify_by_heuristics("assert_valid_output"),
CallerType::Production
); }
// `total_count` is the sum of the two count fields (the vectors are not consulted).
#[test]
fn test_classified_callers_total_count() {
let mut result = ClassifiedCallers::new();
result.production_count = 5;
result.test_count = 10;
assert_eq!(result.total_count(), 15);
}
// `file.rs:func` parses into file + function name.
#[test]
fn test_parse_single_colon_format() {
let func_id = parse_caller_to_func_id("overflow.rs:inline_table_containing_array");
assert!(func_id.is_some());
let id = func_id.unwrap();
assert_eq!(id.name, "inline_table_containing_array");
assert_eq!(id.file.to_string_lossy(), "overflow.rs");
}
// `module::...::func` parses with the last segment as the function name.
#[test]
fn test_parse_double_colon_format() {
let func_id = parse_caller_to_func_id("overflow::test::inline_table");
assert!(func_id.is_some());
let id = func_id.unwrap();
assert_eq!(id.name, "inline_table");
}
// Test-file naming (`test_*.rs`, `*_test.rs`, `tests/`) is detected; ordinary
// source files are not.
#[test]
fn test_is_test_file_path() {
assert!(is_test_file_path("test_overflow.rs:some_func"));
assert!(is_test_file_path("path/to/test_helpers.rs:setup"));
assert!(is_test_file_path("overflow_test.rs:verify"));
assert!(is_test_file_path("tests/integration.rs:test_flow"));
assert!(!is_test_file_path("overflow.rs:reflow_arrays"));
assert!(!is_test_file_path("src/main.rs:main"));
assert!(!is_test_file_path("formatting.rs:process"));
}
// Single-colon `file:func` callers yield the part after the colon.
#[test]
fn test_extract_function_name_single_colon() {
assert_eq!(extract_function_name("file.rs:function"), "function");
assert_eq!(
extract_function_name("overflow.rs:inline_table"),
"inline_table"
);
}
// Graph-backed lookup: a function registered as a test in the call graph is
// classified Test even though its name matches no heuristic.
// NOTE(review): the positional args to `CallGraph::add_function` are not visible
// here. The two booleans appear to end with an is-test flag (test fns pass
// `false, true`, prod fns `true, false`); confirm against `CallGraph`.
#[test]
fn test_is_test_function_by_name_with_call_graph() {
use std::path::PathBuf;
let mut call_graph = CallGraph::new();
let test_fn = FunctionId::new(
PathBuf::from("overflow.rs"),
"inline_table_containing_array".to_string(),
100,
);
call_graph.add_function(
test_fn.clone(),
false, true, 5,
10,
);
let prod_fn = FunctionId::new(
PathBuf::from("overflow.rs"),
"reflow_arrays".to_string(),
50,
);
call_graph.add_function(
prod_fn.clone(),
true, false, 10,
25,
);
assert!(is_test_function_by_name(
"inline_table_containing_array",
&call_graph
));
assert!(!is_test_function_by_name("reflow_arrays", &call_graph));
assert!(!is_test_function_by_name("unknown_function", &call_graph));
assert_eq!(
classify_caller("inline_table_containing_array", Some(&call_graph)),
CallerType::Test
);
assert_eq!(
classify_caller("reflow_arrays", Some(&call_graph)),
CallerType::Production
);
}
// The graph stores `./src/formatting/overflow.rs` but the caller says
// `overflow.rs`. The exact-id lookup cannot rely on the path, so the name
// fallback must still find the test.
#[test]
fn test_path_mismatch_falls_back_to_name_lookup() {
use std::path::PathBuf;
let mut call_graph = CallGraph::new();
let test_fn = FunctionId::new(
PathBuf::from("./src/formatting/overflow.rs"), "vertical_with_comment_stays_vertical".to_string(),
450,
);
call_graph.add_function(
test_fn.clone(),
false, true, 5,
10,
);
let caller = "overflow.rs:vertical_with_comment_stays_vertical";
assert_eq!(
classify_caller(caller, Some(&call_graph)),
CallerType::Test,
"Path mismatch should fall back to name-based lookup"
);
}
// The same path mismatch must not promote a production function to Test.
#[test]
fn test_prod_function_with_path_mismatch_stays_production() {
use std::path::PathBuf;
let mut call_graph = CallGraph::new();
let prod_fn = FunctionId::new(
PathBuf::from("./src/formatting/overflow.rs"),
"reflow_arrays".to_string(),
100,
);
call_graph.add_function(
prod_fn.clone(),
true, false, 10,
25,
);
let caller = "overflow.rs:reflow_arrays";
assert_eq!(
classify_caller(caller, Some(&call_graph)),
CallerType::Production,
"Production functions should remain Production even with path mismatch"
);
}
// A graph entry named `test::func` must match a caller that only names `func`.
#[test]
fn test_module_qualified_name_matching() {
use std::path::PathBuf;
let mut call_graph = CallGraph::new();
let test_fn = FunctionId::new(
PathBuf::from("./src/formatting/overflow.rs"),
"test::short_array_not_reflowed".to_string(), 650,
);
call_graph.add_function(
test_fn.clone(),
false, true, 3,
15,
);
let caller = "overflow.rs:short_array_not_reflowed";
assert_eq!(
classify_caller(caller, Some(&call_graph)),
CallerType::Test,
"Module-qualified name (test::func) should match base name (func)"
);
}
// `is_test_function_by_name` accepts both the bare and the fully qualified name.
#[test]
fn test_is_test_function_by_name_with_module_prefix() {
use std::path::PathBuf;
let mut call_graph = CallGraph::new();
let test_fn = FunctionId::new(
PathBuf::from("overflow.rs"),
"test::vertical_stays_when_too_wide".to_string(),
1241,
);
call_graph.add_function(
test_fn.clone(),
false,
true, 5,
10,
);
assert!(
is_test_function_by_name("vertical_stays_when_too_wide", &call_graph),
"Should match 'test::vertical_stays_when_too_wide' when searching for 'vertical_stays_when_too_wide'"
);
assert!(
is_test_function_by_name("test::vertical_stays_when_too_wide", &call_graph),
"Should match exact name"
);
}
}