use crate::store::{CallGraph, StoreError};
use crate::Store;
use super::bfs::{reverse_bfs, reverse_bfs_multi_attributed, test_reachability};
use super::types::{FunctionHints, RiskLevel, RiskScore};
use super::DEFAULT_MAX_TEST_SEARCH_DEPTH;
/// Lowest risk score (callers × untested fraction) that is classified as
/// [`RiskLevel::High`].
pub const RISK_THRESHOLD_HIGH: f32 = 5.0;
/// Lowest risk score that is classified as [`RiskLevel::Medium`]; scores below
/// this are [`RiskLevel::Low`], except functions with neither callers nor
/// reaching tests, which are always reported as Medium.
pub const RISK_THRESHOLD_MEDIUM: f32 = 2.0;
/// Computes caller and test counts for a single function from a preloaded call graph.
///
/// * `caller_count` is `prefetched_caller_count` when supplied; otherwise it is the
///   number of entries recorded for `function_name` in `graph.reverse` (0 if absent).
/// * `test_count` is the number of chunks in `test_chunks` that are strict ancestors
///   (depth > 0) of `function_name` in a reverse BFS limited to
///   `DEFAULT_MAX_TEST_SEARCH_DEPTH`.
pub fn compute_hints_with_graph(
    graph: &CallGraph,
    test_chunks: &[crate::store::ChunkSummary],
    function_name: &str,
    prefetched_caller_count: Option<usize>,
) -> FunctionHints {
    let _span =
        tracing::debug_span!("compute_hints_with_graph", function = function_name).entered();

    // Only touch the graph when the caller did not already supply a count.
    let caller_count = prefetched_caller_count.unwrap_or_else(|| {
        graph
            .reverse
            .get(function_name)
            .map_or(0, |callers| callers.len())
    });

    let reached_by = reverse_bfs(graph, function_name, DEFAULT_MAX_TEST_SEARCH_DEPTH);
    // Depth 0 is the function itself; only tests strictly upstream of it count.
    let test_count = test_chunks
        .iter()
        .filter(|chunk| matches!(reached_by.get(&chunk.name), Some(&depth) if depth > 0))
        .count();

    FunctionHints {
        caller_count,
        test_count,
    }
}
/// Computes [`FunctionHints`] for one function, loading the call graph and the
/// test chunks from `store`.
///
/// When `prefetched_caller_count` is `None`, the caller count comes from
/// `Store::get_callers_full`, not from the call graph. That lookup happens before
/// the graph is loaded.
///
/// # Errors
///
/// Returns any [`StoreError`] raised while fetching callers, the call graph, or
/// the test chunks.
pub fn compute_hints(
    store: &Store,
    function_name: &str,
    prefetched_caller_count: Option<usize>,
) -> Result<FunctionHints, StoreError> {
    let _span = tracing::info_span!("compute_hints", function = function_name).entered();
    let caller_count = if let Some(count) = prefetched_caller_count {
        count
    } else {
        store.get_callers_full(function_name)?.len()
    };
    let graph = store.get_call_graph()?;
    let test_chunks = store.find_test_chunks()?;
    // Pass the resolved count through so the graph-based helper does not recount.
    let hints = compute_hints_with_graph(&graph, &test_chunks, function_name, Some(caller_count));
    Ok(hints)
}
/// Computes [`FunctionHints`] for many functions with a single test-reachability pass.
///
/// The result has one entry per element of `names`, in the same order.
///
/// * `caller_count` comes from `caller_counts` when the name is present there;
///   otherwise it is the length of the name's entry in `graph.reverse` (0 if absent).
/// * `test_count` is the value reported by `test_reachability` for the name (0 if absent).
pub fn compute_hints_batch(
    graph: &CallGraph,
    test_chunks: &[crate::store::ChunkSummary],
    names: &[&str],
    caller_counts: &std::collections::HashMap<String, u64>,
) -> Vec<FunctionHints> {
    let _span = tracing::info_span!("compute_hints_batch", count = names.len()).entered();
    let test_names: Vec<&str> = test_chunks
        .iter()
        .map(|chunk| chunk.name.as_str())
        .collect();
    let reachability = test_reachability(graph, &test_names, DEFAULT_MAX_TEST_SEARCH_DEPTH);

    let mut hints = Vec::with_capacity(names.len());
    for &name in names {
        let caller_count = match caller_counts.get(name) {
            Some(&prefetched) => prefetched as usize,
            // No prefetched value: fall back to the graph's reverse edges.
            None => graph.reverse.get(name).map_or(0, |callers| callers.len()),
        };
        let test_count = reachability.get(name).copied().unwrap_or(0);
        hints.push(FunctionHints {
            caller_count,
            test_count,
        });
    }
    hints
}
/// Scores the change risk of each function in `names`, returning one
/// [`RiskScore`] per name in input order.
///
/// For each name:
/// * `caller_count` is the length of its entry in `graph.reverse` (0 if absent).
/// * `test_count` is the value from `test_reachability` over `test_chunks` (0 if absent).
/// * See [`risk_score_from_counts`] for how those two counts become a score.
pub fn compute_risk_batch(
    names: &[&str],
    graph: &CallGraph,
    test_chunks: &[crate::store::ChunkSummary],
) -> Vec<RiskScore> {
    let _span = tracing::info_span!("compute_risk_batch", count = names.len()).entered();
    let test_names: Vec<&str> = test_chunks
        .iter()
        .map(|chunk| chunk.name.as_str())
        .collect();
    let reachability = test_reachability(graph, &test_names, DEFAULT_MAX_TEST_SEARCH_DEPTH);
    names
        .iter()
        .map(|&name| {
            let caller_count = graph.reverse.get(name).map_or(0, |callers| callers.len());
            let test_count = reachability.get(name).copied().unwrap_or(0);
            risk_score_from_counts(caller_count, test_count)
        })
        .collect()
}

/// Builds a [`RiskScore`] from a function's direct caller count and reaching-test count.
///
/// * `test_ratio` is `test_count / caller_count`, capped at 1.0. With zero callers it
///   is 1.0 if any test reaches the function and 0.0 otherwise.
/// * `score` is `caller_count * (1 - test_ratio)`. It is compared against
///   `RISK_THRESHOLD_HIGH` / `RISK_THRESHOLD_MEDIUM` to pick `risk_level`.
/// * `blast_radius` depends only on `caller_count`: Low for 0-2, Medium for 3-10,
///   and High above 10.
fn risk_score_from_counts(caller_count: usize, test_count: usize) -> RiskScore {
    let callers = caller_count as f32;
    let test_ratio = match (caller_count, test_count) {
        (0, 0) => 0.0,
        (0, _) => 1.0,
        _ => (test_count as f32 / callers).min(1.0),
    };
    let score = callers * (1.0 - test_ratio);
    let risk_level = match (caller_count, test_count) {
        // Neither called nor tested (e.g. an entry point): flagged Medium, not Low.
        (0, 0) => RiskLevel::Medium,
        _ if score >= RISK_THRESHOLD_HIGH => RiskLevel::High,
        _ if score >= RISK_THRESHOLD_MEDIUM => RiskLevel::Medium,
        _ => RiskLevel::Low,
    };
    let blast_radius = match caller_count {
        0..=2 => RiskLevel::Low,
        3..=10 => RiskLevel::Medium,
        _ => RiskLevel::High,
    };
    RiskScore {
        caller_count,
        test_count,
        test_ratio,
        risk_level,
        blast_radius,
        score,
    }
}
pub fn compute_risk_and_tests(
targets: &[&str],
graph: &CallGraph,
test_chunks: &[crate::store::ChunkSummary],
) -> (Vec<RiskScore>, Vec<super::TestInfo>) {
let _span = tracing::info_span!("compute_risk_and_tests", targets = targets.len()).entered();
let test_names: Vec<&str> = test_chunks.iter().map(|t| t.name.as_str()).collect();
let reachability = test_reachability(graph, &test_names, DEFAULT_MAX_TEST_SEARCH_DEPTH);
let ancestors = reverse_bfs_multi_attributed(graph, targets, DEFAULT_MAX_TEST_SEARCH_DEPTH);
let mut all_tests = Vec::new();
let mut seen_tests = std::collections::HashSet::new();
let mut tests_per_target: Vec<std::collections::HashSet<&str>> =
vec![std::collections::HashSet::new(); targets.len()];
for test in test_chunks {
if let Some(&(depth, source_idx)) = ancestors.get(&test.name) {
if depth > 0 {
if source_idx < targets.len() {
tests_per_target[source_idx].insert(&test.name);
}
if seen_tests.insert((test.name.clone(), test.file.clone())) {
all_tests.push(super::TestInfo {
name: test.name.clone(),
file: test.file.clone(),
line: test.line_start,
call_depth: depth,
});
}
}
}
}
let mut scores = Vec::with_capacity(targets.len());
for (i, &name) in targets.iter().enumerate() {
let caller_count = graph.reverse.get(name).map(|v| v.len()).unwrap_or(0);
let test_count = reachability.get(name).copied().unwrap_or(0);
let test_ratio = if caller_count == 0 {
if test_count > 0 {
1.0
} else {
0.0
}
} else {
(test_count as f32 / caller_count as f32).min(1.0)
};
let score = caller_count as f32 * (1.0 - test_ratio);
let risk_level = if caller_count == 0 && test_count == 0 {
RiskLevel::Medium
} else if score >= RISK_THRESHOLD_HIGH {
RiskLevel::High
} else if score >= RISK_THRESHOLD_MEDIUM {
RiskLevel::Medium
} else {
RiskLevel::Low
};
let blast_radius = match caller_count {
0..=2 => RiskLevel::Low,
3..=10 => RiskLevel::Medium,
_ => RiskLevel::High,
};
let _ = &tests_per_target[i]; scores.push(RiskScore {
caller_count,
test_count,
test_ratio,
risk_level,
blast_radius,
score,
});
}
all_tests.sort_by_key(|t| t.call_depth);
(scores, all_tests)
}
/// Returns up to `top_n` functions with the most direct callers in `graph`.
///
/// Results are ordered by descending caller count. Functions with equal counts are
/// ordered by name, so the `top_n` cut is deterministic. Without this tie-break the
/// order would follow the map's iteration order.
pub fn find_hotspots(graph: &CallGraph, top_n: usize) -> Vec<crate::health::Hotspot> {
    let _span = tracing::info_span!("find_hotspots", top_n).entered();
    let mut hotspots: Vec<crate::health::Hotspot> = graph
        .reverse
        .iter()
        .map(|(name, callers)| crate::health::Hotspot {
            name: name.to_string(),
            caller_count: callers.len(),
        })
        .collect();
    // Names are unique map keys, so (count desc, name asc) is a total order and an
    // unstable sort is deterministic.
    hotspots.sort_unstable_by(|a, b| {
        b.caller_count
            .cmp(&a.caller_count)
            .then_with(|| a.name.cmp(&b.name))
    });
    hotspots.truncate(top_n);
    hotspots
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::ChunkSummary;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Builds a Rust function chunk named `name` in `file`, spanning lines 1-5 with
    /// empty signature/content. Override fields with struct-update syntax as needed.
    fn test_chunk(id: &str, file: &str, name: &str) -> ChunkSummary {
        ChunkSummary {
            id: id.to_string(),
            file: PathBuf::from(file),
            language: crate::parser::Language::Rust,
            chunk_type: crate::language::ChunkType::Function,
            name: name.to_string(),
            signature: String::new(),
            content: String::new(),
            doc: None,
            line_start: 1,
            line_end: 5,
            parent_id: None,
            parent_type_name: None,
            content_hash: String::new(),
            window_idx: None,
        }
    }

    #[test]
    fn test_compute_hints_with_graph_stale_callers() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            vec!["ghost_caller".to_string(), "another_ghost".to_string()],
        );
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let hints = compute_hints_with_graph(&graph, &test_chunks, "target", None);
        assert_eq!(hints.caller_count, 2, "Should count callers from graph");
        assert_eq!(hints.test_count, 0, "No test chunks means no tests");
    }

    #[test]
    fn test_compute_hints_with_graph_stale_test_ancestor() {
        let mut reverse = HashMap::new();
        reverse.insert("target".to_string(), vec!["middle".to_string()]);
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks = vec![ChunkSummary {
            signature: "fn test_fn()".to_string(),
            content: "#[test] fn test_fn() {}".to_string(),
            ..test_chunk("test.rs:1:abcd1234", "test.rs", "test_fn")
        }];
        let hints = compute_hints_with_graph(&graph, &test_chunks, "target", None);
        assert_eq!(hints.test_count, 0, "Unreachable test should not count");
        assert_eq!(hints.caller_count, 1, "middle is a caller");
    }

    #[test]
    fn test_compute_hints_with_graph_prefetched_caller_count() {
        let graph = CallGraph::from_string_maps(HashMap::new(), HashMap::new());
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let hints = compute_hints_with_graph(&graph, &test_chunks, "target", Some(99));
        assert_eq!(hints.caller_count, 99, "Should use prefetched value");
    }

    #[test]
    fn test_risk_high_many_callers_no_tests() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            vec!["a", "b", "c", "d", "e", "f", "g"]
                .into_iter()
                .map(String::from)
                .collect(),
        );
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].risk_level, RiskLevel::High);
        assert_eq!(scores[0].caller_count, 7);
        assert_eq!(scores[0].test_count, 0);
        assert!((scores[0].score - 7.0).abs() < 0.01);
    }

    #[test]
    fn test_risk_low_with_tests() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            vec!["a".to_string(), "test_target".to_string()],
        );
        let mut forward = HashMap::new();
        forward.insert("test_target".to_string(), vec!["target".to_string()]);
        forward.insert("a".to_string(), vec!["target".to_string()]);
        let graph = CallGraph::from_string_maps(forward, reverse);
        let test_chunks = vec![ChunkSummary {
            line_end: 10,
            ..test_chunk("test_id", "tests/test.rs", "test_target")
        }];
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores[0].risk_level, RiskLevel::Low);
        assert!((scores[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_risk_entry_point_no_callers_no_tests() {
        let graph = CallGraph::from_string_maps(HashMap::new(), HashMap::new());
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["main"], &graph, &test_chunks);
        assert_eq!(scores[0].risk_level, RiskLevel::Medium);
        assert_eq!(scores[0].caller_count, 0);
        assert_eq!(scores[0].test_count, 0);
    }

    #[test]
    fn test_risk_coverage_capped_at_one() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            vec![
                "a".to_string(),
                "test_a".to_string(),
                "test_b".to_string(),
                "test_c".to_string(),
            ],
        );
        let mut forward = HashMap::new();
        forward.insert("test_a".to_string(), vec!["target".to_string()]);
        forward.insert("test_b".to_string(), vec!["target".to_string()]);
        forward.insert("test_c".to_string(), vec!["target".to_string()]);
        let graph = CallGraph::from_string_maps(forward, reverse);
        let test_chunks = vec![
            test_chunk("t1", "tests/t.rs", "test_a"),
            ChunkSummary {
                line_start: 6,
                line_end: 10,
                ..test_chunk("t2", "tests/t.rs", "test_b")
            },
            ChunkSummary {
                line_start: 11,
                line_end: 15,
                ..test_chunk("t3", "tests/t.rs", "test_c")
            },
        ];
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert!(
            scores[0].test_ratio <= 1.0,
            "test_ratio should be capped at 1.0, got {}",
            scores[0].test_ratio
        );
        assert_eq!(scores[0].risk_level, RiskLevel::Low);
    }

    #[test]
    fn test_risk_batch_empty_input() {
        let graph = CallGraph::from_string_maps(HashMap::new(), HashMap::new());
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&[], &graph, &test_chunks);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_blast_radius_thresholds() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "low_blast".to_string(),
            vec!["a", "b"].into_iter().map(String::from).collect(),
        );
        reverse.insert(
            "med_blast".to_string(),
            vec!["a", "b", "c"].into_iter().map(String::from).collect(),
        );
        reverse.insert(
            "high_blast".to_string(),
            (0..11).map(|i| format!("c{i}")).collect(),
        );
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(
            &["low_blast", "med_blast", "high_blast"],
            &graph,
            &test_chunks,
        );
        assert_eq!(scores[0].blast_radius, RiskLevel::Low);
        assert_eq!(scores[1].blast_radius, RiskLevel::Medium);
        assert_eq!(scores[2].blast_radius, RiskLevel::High);
    }

    #[test]
    fn test_blast_radius_differs_from_risk() {
        let mut reverse = HashMap::new();
        let mut all: Vec<String> = (0..15).map(|i| format!("caller_{i}")).collect();
        all.push("test_target".to_string());
        reverse.insert("target".to_string(), all);
        let mut forward = HashMap::new();
        forward.insert("test_target".to_string(), vec!["target".to_string()]);
        let graph = CallGraph::from_string_maps(forward, reverse);
        let test_chunks = vec![test_chunk("t1", "tests/t.rs", "test_target")];
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores[0].blast_radius, RiskLevel::High);
        assert_eq!(scores[0].caller_count, 16);
    }

    #[test]
    fn test_blast_radius_boundary_10_callers_is_medium() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "ten_callers".to_string(),
            (0..10).map(|i| format!("c{i}")).collect(),
        );
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["ten_callers"], &graph, &test_chunks);
        assert_eq!(
            scores[0].blast_radius,
            RiskLevel::Medium,
            "10 callers should be Medium blast radius (3..=10)"
        );
        assert_eq!(scores[0].caller_count, 10);
    }

    #[test]
    fn test_risk_score_formula_many_callers_no_tests() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            (0..6).map(|i| format!("c{i}")).collect(),
        );
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores[0].risk_level, RiskLevel::High);
        assert!((scores[0].score - 6.0).abs() < 0.01);
        assert_eq!(scores[0].test_count, 0);
        assert!((scores[0].test_ratio - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_risk_medium_boundary() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            vec!["a", "b", "c"].into_iter().map(String::from).collect(),
        );
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores[0].risk_level, RiskLevel::Medium);
        assert!((scores[0].score - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_risk_low_below_medium_threshold() {
        let mut reverse = HashMap::new();
        reverse.insert("target".to_string(), vec!["a".to_string()]);
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores[0].risk_level, RiskLevel::Low);
        assert!((scores[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_risk_zero_callers_with_test_is_low() {
        let mut forward = HashMap::new();
        forward.insert("test_fn".to_string(), vec!["target".to_string()]);
        let graph = CallGraph::from_string_maps(forward, HashMap::new());
        let test_chunks = vec![test_chunk("t1", "tests/t.rs", "test_fn")];
        let scores = compute_risk_batch(&["target"], &graph, &test_chunks);
        assert_eq!(scores[0].risk_level, RiskLevel::Low);
        assert_eq!(scores[0].caller_count, 0);
        assert_eq!(scores[0].test_count, 1);
    }

    #[test]
    fn test_blast_radius_boundary_0_callers() {
        let graph = CallGraph::from_string_maps(HashMap::new(), HashMap::new());
        let test_chunks: Vec<ChunkSummary> = Vec::new();
        let scores = compute_risk_batch(&["orphan"], &graph, &test_chunks);
        assert_eq!(
            scores[0].blast_radius,
            RiskLevel::Low,
            "0 callers should be Low blast radius (0..=2)"
        );
    }

    #[test]
    fn test_risk_and_tests_reports_each_reaching_test_once() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "target".to_string(),
            vec!["a".to_string(), "test_target".to_string()],
        );
        let mut forward = HashMap::new();
        forward.insert("test_target".to_string(), vec!["target".to_string()]);
        forward.insert("a".to_string(), vec!["target".to_string()]);
        let graph = CallGraph::from_string_maps(forward, reverse);
        let test_chunks = vec![
            test_chunk("t1", "tests/t.rs", "test_target"),
            // Same (name, file) as above: must be reported only once.
            ChunkSummary {
                line_start: 6,
                line_end: 10,
                ..test_chunk("t1b", "tests/t.rs", "test_target")
            },
            // Not connected to `target` at all: must be filtered out.
            test_chunk("t2", "tests/t.rs", "test_unrelated"),
        ];
        let (scores, tests) = compute_risk_and_tests(&["target"], &graph, &test_chunks);
        let batch = compute_risk_batch(&["target"], &graph, &test_chunks);

        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].caller_count, batch[0].caller_count);
        assert_eq!(scores[0].test_count, batch[0].test_count);
        assert_eq!(scores[0].risk_level, batch[0].risk_level);
        assert_eq!(scores[0].blast_radius, batch[0].blast_radius);

        assert_eq!(tests.len(), 1, "duplicate and unreachable tests excluded");
        assert_eq!(tests[0].name, "test_target");
        assert_eq!(tests[0].file, PathBuf::from("tests/t.rs"));
        assert_eq!(tests[0].call_depth, 1, "direct caller is one hop away");
    }

    #[test]
    fn test_find_hotspots() {
        let mut reverse = HashMap::new();
        reverse.insert(
            "hot".to_string(),
            vec!["a", "b", "c"].into_iter().map(String::from).collect(),
        );
        reverse.insert(
            "warm".to_string(),
            vec!["a", "b"].into_iter().map(String::from).collect(),
        );
        reverse.insert("cold".to_string(), vec!["a".to_string()]);
        let graph = CallGraph::from_string_maps(HashMap::new(), reverse);
        let hotspots = find_hotspots(&graph, 2);
        assert_eq!(hotspots.len(), 2);
        assert_eq!(hotspots[0].name, "hot");
        assert_eq!(hotspots[0].caller_count, 3);
        assert_eq!(hotspots[1].name, "warm");
        assert_eq!(hotspots[1].caller_count, 2);
    }
}