use crate::calibrate::NgramModel;
use std::collections::{HashMap, HashSet};
/// Extracts dependency chains from `calls` without limiting how many are returned.
///
/// `calls` is a list of `(caller, callee)` pairs and `max_depth` is forwarded
/// unchanged. This is a convenience wrapper around
/// [`extract_dependency_chains_bounded`] that passes `usize::MAX` as the chain cap.
pub fn extract_dependency_chains(calls: &[(String, String)], max_depth: usize) -> Vec<Vec<String>> {
    let no_chain_limit = usize::MAX;
    extract_dependency_chains_bounded(calls, max_depth, no_chain_limit)
}
/// Extracts call chains (simple paths through the call graph) from a list of
/// `(caller, callee)` edges.
///
/// Every distinct caller, in order of first appearance in `calls`, is used as a
/// starting point for a depth-first walk. A path is emitted as a chain when it
/// either reaches `max_depth` nodes or cannot be extended without revisiting a
/// node already on the path, so cycles never produce repeated nodes.
///
/// # Parameters
/// - `calls`: directed edges as `(caller, callee)` pairs. A pair that occurs more
///   than once is treated as a single edge, so it cannot produce duplicate chains.
/// - `max_depth`: maximum number of nodes in a chain. A chain needs at least two
///   nodes, so values below 2 yield no chains.
/// - `max_chains`: upper bound on the number of chains returned; extraction stops
///   as soon as it is reached.
///
/// # Returns
/// The collected chains, each holding between 2 and `max_depth` node names.
pub fn extract_dependency_chains_bounded(
    calls: &[(String, String)],
    max_depth: usize,
    max_chains: usize,
) -> Vec<Vec<String>> {
    // A chain is at least one edge (two nodes); nothing can qualify otherwise.
    if max_depth < 2 || max_chains == 0 {
        return Vec::new();
    }

    // Adjacency list over borrowed names, plus the callers in first-seen order.
    let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
    let mut starts: Vec<&str> = Vec::new();
    let mut seen_edges: HashSet<(&str, &str)> = HashSet::new();
    for (caller, callee) in calls {
        let (caller, callee) = (caller.as_str(), callee.as_str());
        // Several call sites for the same edge must not multiply the chains.
        if !seen_edges.insert((caller, callee)) {
            continue;
        }
        adj.entry(caller)
            .or_insert_with(|| {
                starts.push(caller);
                Vec::new()
            })
            .push(callee);
    }

    let mut chains: Vec<Vec<String>> = Vec::new();
    let mut stack: Vec<Vec<&str>> = Vec::new();

    'roots: for start in starts {
        stack.clear();
        stack.push(vec![start]);

        while let Some(path) = stack.pop() {
            if chains.len() >= max_chains {
                break 'roots;
            }

            let last = *path.last().expect("path is non-empty by construction");
            let mut extended = false;

            // Only grow paths that are still below the depth limit.
            if path.len() < max_depth {
                if let Some(callees) = adj.get(last) {
                    for &callee in callees {
                        // The path doubles as the visited set for cycle detection.
                        if !path.contains(&callee) {
                            let mut next = Vec::with_capacity(path.len() + 1);
                            next.extend_from_slice(&path);
                            next.push(callee);
                            stack.push(next);
                            extended = true;
                        }
                    }
                }
            }

            // Emit dead ends and depth-limited paths, but never a lone node.
            if !extended && path.len() >= 2 {
                chains.push(path.iter().map(|name| (*name).to_owned()).collect());
            }
        }
    }

    chains
}
/// Scores the source lines of a chain against `model`.
///
/// Each line is tokenized with `NgramModel::tokenize_line`; lines that produce
/// no tokens are skipped, and every remaining line's tokens are followed by an
/// `<EOL>` marker. The concatenated token stream is passed to `model.surprisal`.
///
/// Returns `0.0` without tokenizing anything when the model is not confident or
/// `chain_lines` is empty.
pub fn chain_surprisal(model: &NgramModel, chain_lines: &[&str]) -> f64 {
    let can_score = model.is_confident() && !chain_lines.is_empty();
    if !can_score {
        return 0.0;
    }

    let tokens: Vec<String> = chain_lines
        .iter()
        .map(|line| NgramModel::tokenize_line(line))
        .filter(|line_tokens| !line_tokens.is_empty())
        .flat_map(|line_tokens| {
            // Terminate every contributing line with the end-of-line marker.
            line_tokens
                .into_iter()
                .chain(std::iter::once("<EOL>".to_string()))
        })
        .collect();

    model.surprisal(&tokens)
}
/// Tracks, for each function name, the highest surprisal of any chain the
/// function has been recorded in.
#[derive(Debug, Clone, Default)]
pub struct DependencyChainScorer {
    /// Best (largest) surprisal seen so far, keyed by function name.
    /// Entries never drop below `0.0`.
    pub scores: HashMap<String, f64>,
}

impl DependencyChainScorer {
    /// Creates a scorer with no recorded chains.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `surprisal` for every function named in `chain_qns`.
    ///
    /// A function's stored score is raised to `surprisal` only when `surprisal`
    /// is strictly greater than the current value; a function seen for the
    /// first time starts from `0.0`, so non-positive (or NaN) values leave it
    /// at `0.0`.
    pub fn record_chain(&mut self, chain_qns: &[String], surprisal: f64) {
        for qn in chain_qns {
            match self.scores.get_mut(qn.as_str()) {
                // Known function: update in place without cloning the key.
                Some(best) => {
                    if surprisal > *best {
                        *best = surprisal;
                    }
                }
                // New function: same result as inserting 0.0 and then raising it.
                None => {
                    let initial = if surprisal > 0.0 { surprisal } else { 0.0 };
                    self.scores.insert(qn.clone(), initial);
                }
            }
        }
    }

    /// Returns the stored score for `function_qn`, or `0.0` if the function was
    /// never recorded.
    pub fn score(&self, function_qn: &str) -> f64 {
        self.scores.get(function_qn).copied().unwrap_or(0.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model trained 800 times on one tokenized `let` statement and
    /// checks that it reports itself as confident.
    fn trained_model() -> NgramModel {
        let mut model = NgramModel::new();
        for _ in 0..800 {
            let sample: [String; 7] =
                ["let", "mut", "<ID>", "=", "<NUM>", ";", "<EOL>"].map(String::from);
            model.train_on_tokens(&sample);
        }
        assert!(
            model.is_confident(),
            "Model should be confident after training"
        );
        model
    }

    /// A linear graph A -> B -> C yields both the full chain and its suffix.
    #[test]
    fn test_extract_dependency_chains() {
        let edge = |from: &str, to: &str| (from.to_string(), to.to_string());
        let calls = vec![edge("A", "B"), edge("B", "C")];

        let chains = extract_dependency_chains(&calls, 4);

        assert!(
            chains.iter().any(|c| *c == ["A", "B", "C"]),
            "Expected chain [A, B, C] in extracted chains: {:?}",
            chains
        );
        assert!(
            chains.iter().any(|c| *c == ["B", "C"]),
            "Expected chain [B, C] in extracted chains: {:?}",
            chains
        );
    }

    /// A two-node cycle still produces chains, none of which repeats a node.
    #[test]
    fn test_extract_chains_handles_cycles() {
        let edge = |from: &str, to: &str| (from.to_string(), to.to_string());
        let calls = vec![edge("A", "B"), edge("B", "A")];

        let chains = extract_dependency_chains(&calls, 10);

        assert!(
            !chains.is_empty(),
            "Should extract at least one chain from a cycle"
        );
        chains.iter().for_each(|chain| {
            let mut distinct = std::collections::HashSet::new();
            for node in chain {
                assert!(
                    distinct.insert(node),
                    "Chain contains duplicate node {:?}: {:?}",
                    node,
                    chain
                );
            }
        });
    }

    /// With a confident model the surprisal of a chain is never negative.
    #[test]
    fn test_chain_surprisal_computation() {
        let model = trained_model();
        let score = chain_surprisal(&model, &["let mut count = 0;", "let mut total = 42;"]);
        assert!(
            score >= 0.0,
            "Chain surprisal should be non-negative, got {}",
            score
        );
    }

    /// A freshly created model is not confident, so scoring returns 0.0.
    #[test]
    fn test_chain_surprisal_zero_without_confidence() {
        let model = NgramModel::new();
        assert!(!model.is_confident());

        let score = chain_surprisal(&model, &["let x = 1;", "let y = 2;"]);
        assert_eq!(
            score, 0.0,
            "Unconfident model should return 0.0, got {}",
            score
        );
    }

    /// Each function keeps the largest surprisal of the chains it appears in.
    #[test]
    fn test_dependency_chain_scorer_max() {
        let chain_a = vec!["foo".to_string(), "bar".to_string()];
        let chain_b = vec!["foo".to_string(), "baz".to_string()];

        let mut scorer = DependencyChainScorer::new();
        scorer.record_chain(&chain_a, 3.5);
        scorer.record_chain(&chain_b, 7.2);

        let close_to = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

        assert!(
            close_to(scorer.score("foo"), 7.2),
            "Expected max surprisal 7.2 for foo, got {}",
            scorer.score("foo")
        );
        assert!(
            close_to(scorer.score("bar"), 3.5),
            "Expected surprisal 3.5 for bar, got {}",
            scorer.score("bar")
        );
        assert!(
            close_to(scorer.score("baz"), 7.2),
            "Expected surprisal 7.2 for baz, got {}",
            scorer.score("baz")
        );
    }

    /// Looking up a function that was never recorded yields 0.0.
    #[test]
    fn test_score_missing_function() {
        let scorer = DependencyChainScorer::new();
        let missing = scorer.score("nonexistent::function");
        assert_eq!(missing, 0.0, "Unknown function should return 0.0");
    }
}