use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::gridsearch::{ParameterGrid, ParameterPoint};
/// Aggregated result of a parameter sensitivity run, as assembled by
/// `full_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityAnalysis {
/// Parameter name -> score drop observed when that parameter is ablated
/// (clamped to be non-negative by `ablation_study`).
pub importance: HashMap<String, f64>,
/// Parameter name -> 1-based rank by descending `importance`.
pub importance_rank: HashMap<String, usize>,
/// Parameter name -> one-at-a-time curve of `(parameter value, score)` samples.
pub oat_curves: HashMap<String, Vec<(f64, f64)>>,
/// `"param_a|param_b"` -> interaction strength: the joint ablation effect
/// minus the sum of the two individual effects. Only pairs whose absolute
/// strength exceeds 0.01 are stored by `detect_interactions`.
pub interactions: HashMap<String, f64>,
}
/// Estimates how much each parameter contributes to the baseline score.
///
/// Every name reported by `ParameterGrid::default().param_names()` is
/// neutralised in turn with `ablate_param`, the resulting point is scored
/// with `evaluator`, and the drop relative to `baseline_score` is recorded.
///
/// # Returns
/// A map from parameter name to score drop. Drops are clamped at `0.0`, so a
/// parameter whose ablation *improves* the score is reported as `0.0`.
pub fn ablation_study<F>(
    baseline: &ParameterPoint,
    baseline_score: f64,
    evaluator: F,
) -> HashMap<String, f64>
where
    F: Fn(&ParameterPoint) -> f64,
{
    let grid = ParameterGrid::default();
    grid.param_names()
        .into_iter()
        .map(|param| {
            let without_param = ablate_param(baseline, &param);
            let score_drop = baseline_score - evaluator(&without_param);
            // Only positive contributions count as importance.
            (param, score_drop.max(0.0))
        })
        .collect()
}
/// Returns a copy of `baseline` with the named parameter reset to its
/// ablation value.
///
/// Most parameters are reset to `1.0`; `pagerank_alpha` (0.85),
/// `git_recency_decay_days` (30.0), `git_churn_threshold` (10.0) and
/// `focus_decay` (0.5) use their own values. An unrecognised `param` yields
/// an unmodified clone.
fn ablate_param(baseline: &ParameterPoint, param: &str) -> ParameterPoint {
    let mut ablated = baseline.clone();

    // Resolve the field to overwrite together with the value it is reset to.
    let reset: Option<(&mut f64, f64)> = match param {
        "boost_mentioned_ident" => Some((&mut ablated.boost_mentioned_ident, 1.0)),
        "boost_mentioned_file" => Some((&mut ablated.boost_mentioned_file, 1.0)),
        "boost_chat_file" => Some((&mut ablated.boost_chat_file, 1.0)),
        "boost_temporal_coupling" => Some((&mut ablated.boost_temporal_coupling, 1.0)),
        "boost_focus_expansion" => Some((&mut ablated.boost_focus_expansion, 1.0)),
        "pagerank_chat_multiplier" => Some((&mut ablated.pagerank_chat_multiplier, 1.0)),
        "pagerank_alpha" => Some((&mut ablated.pagerank_alpha, 0.85)),
        "depth_weight_root" => Some((&mut ablated.depth_weight_root, 1.0)),
        "depth_weight_moderate" => Some((&mut ablated.depth_weight_moderate, 1.0)),
        "depth_weight_deep" => Some((&mut ablated.depth_weight_deep, 1.0)),
        "depth_weight_vendor" => Some((&mut ablated.depth_weight_vendor, 1.0)),
        "git_recency_decay_days" => Some((&mut ablated.git_recency_decay_days, 30.0)),
        "git_recency_max_boost" => Some((&mut ablated.git_recency_max_boost, 1.0)),
        "git_churn_threshold" => Some((&mut ablated.git_churn_threshold, 10.0)),
        "git_churn_max_boost" => Some((&mut ablated.git_churn_max_boost, 1.0)),
        "focus_decay" => Some((&mut ablated.focus_decay, 0.5)),
        "focus_max_hops" => Some((&mut ablated.focus_max_hops, 1.0)),
        _ => None,
    };

    if let Some((field, neutral)) = reset {
        *field = neutral;
    }
    ablated
}
/// Computes one-at-a-time (OAT) sensitivity curves.
///
/// For every parameter in the default `ParameterGrid`, the parameter is swept
/// across `n_points` evenly spaced positions `t` in `[0, 1]`, each position is
/// mapped to a concrete value with the parameter's range `decode`, and the
/// point obtained by substituting that value into `baseline` is scored with
/// `evaluator`. All other parameters keep their baseline values.
///
/// # Arguments
/// * `baseline` - point whose remaining parameters are held fixed.
/// * `evaluator` - scoring function applied to each sampled point.
/// * `n_points` - number of samples per parameter. `0` yields empty curves;
///   `1` yields a single sample at `t = 0.0`.
///
/// # Returns
/// A map from parameter name to its `(value, score)` samples in sweep order.
///
/// NOTE(review): `grid.ranges` is indexed by name, so a name returned by
/// `param_names()` without a matching range entry would presumably panic.
pub fn oat_sensitivity<F>(
    baseline: &ParameterPoint,
    evaluator: F,
    n_points: usize,
) -> HashMap<String, Vec<(f64, f64)>>
where
    F: Fn(&ParameterPoint) -> f64,
{
    let grid = ParameterGrid::default();

    // Number of intervals between consecutive samples. Clamped to 1 so that a
    // single-sample sweep evaluates t = 0.0 instead of 0.0 / 0.0 = NaN.
    let intervals = if n_points > 1 {
        (n_points - 1) as f64
    } else {
        1.0
    };

    let mut curves = HashMap::new();
    for name in grid.param_names() {
        let range = &grid.ranges[&name];
        let curve: Vec<(f64, f64)> = (0..n_points)
            .map(|i| {
                let t = i as f64 / intervals;
                let value = range.decode(t);
                let score = evaluator(&set_param(baseline, &name, value));
                (value, score)
            })
            .collect();
        curves.insert(name, curve);
    }
    curves
}
/// Returns a copy of `baseline` in which the parameter called `param` is
/// replaced by `value`.
///
/// An unrecognised `param` yields an unmodified clone.
fn set_param(baseline: &ParameterPoint, param: &str, value: f64) -> ParameterPoint {
    let mut updated = baseline.clone();

    // Look up the field that corresponds to the parameter name.
    let field: Option<&mut f64> = match param {
        "pagerank_alpha" => Some(&mut updated.pagerank_alpha),
        "pagerank_chat_multiplier" => Some(&mut updated.pagerank_chat_multiplier),
        "depth_weight_root" => Some(&mut updated.depth_weight_root),
        "depth_weight_moderate" => Some(&mut updated.depth_weight_moderate),
        "depth_weight_deep" => Some(&mut updated.depth_weight_deep),
        "depth_weight_vendor" => Some(&mut updated.depth_weight_vendor),
        "boost_mentioned_ident" => Some(&mut updated.boost_mentioned_ident),
        "boost_mentioned_file" => Some(&mut updated.boost_mentioned_file),
        "boost_chat_file" => Some(&mut updated.boost_chat_file),
        "boost_temporal_coupling" => Some(&mut updated.boost_temporal_coupling),
        "boost_focus_expansion" => Some(&mut updated.boost_focus_expansion),
        "git_recency_decay_days" => Some(&mut updated.git_recency_decay_days),
        "git_recency_max_boost" => Some(&mut updated.git_recency_max_boost),
        "git_churn_threshold" => Some(&mut updated.git_churn_threshold),
        "git_churn_max_boost" => Some(&mut updated.git_churn_max_boost),
        "focus_decay" => Some(&mut updated.focus_decay),
        "focus_max_hops" => Some(&mut updated.focus_max_hops),
        _ => None,
    };

    if let Some(slot) = field {
        *slot = value;
    }
    updated
}
/// Looks for pairwise interactions among the most important parameters.
///
/// The parameters of the default `ParameterGrid` are ordered by their entry
/// in `ablation` (descending, missing entries counted as `0.0`) and the first
/// six are kept. For every unordered pair of those, both parameters are
/// ablated together and the joint score drop is compared with the sum of the
/// two individual drops taken from `ablation`.
///
/// # Returns
/// A map keyed by `"param_a|param_b"` holding `joint - (a + b)` for every
/// pair whose absolute difference exceeds `0.01`.
///
/// NOTE(review): the joint drop computed here is not clamped, whereas
/// `ablation_study` clamps individual drops at `0.0`; when `ablation` comes
/// from that function, pairs involving a clamped parameter may show an
/// interaction that is only an artefact of the clamping.
pub fn detect_interactions<F>(
    baseline: &ParameterPoint,
    baseline_score: f64,
    evaluator: F,
    ablation: &HashMap<String, f64>,
) -> HashMap<String, f64>
where
    F: Fn(&ParameterPoint) -> f64,
{
    let grid = ParameterGrid::default();
    let names = grid.param_names();

    // Pair every parameter with its individual effect, strongest first.
    let mut ranked: Vec<(String, f64)> = names
        .iter()
        .map(|name| (name.clone(), ablation.get(name).copied().unwrap_or(0.0)))
        .collect();
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let strongest: Vec<(String, f64)> = ranked.into_iter().take(6).collect();

    let mut interactions = HashMap::new();
    for (idx, (first, effect_first)) in strongest.iter().enumerate() {
        // Only pair with later entries so each unordered pair is visited once.
        for (second, effect_second) in &strongest[idx + 1..] {
            let both_ablated = ablate_param(&ablate_param(baseline, first), second);
            let joint_effect = baseline_score - evaluator(&both_ablated);
            let additive_effect = effect_first + effect_second;
            let interaction = joint_effect - additive_effect;
            if interaction.abs() > 0.01 {
                interactions.insert(format!("{}|{}", first, second), interaction);
            }
        }
    }
    interactions
}
/// Runs the complete sensitivity pipeline around `baseline`.
///
/// Steps, in order:
/// 1. score `baseline` with `evaluator`;
/// 2. run `ablation_study` to obtain per-parameter importance;
/// 3. rank parameters by descending importance (1-based);
/// 4. compute one-at-a-time curves with 7 samples per parameter;
/// 5. run `detect_interactions` using the importance map.
///
/// Parameters with equal importance are ranked alphabetically by name, so
/// the resulting `importance_rank` is reproducible across runs.
pub fn full_analysis<F>(baseline: &ParameterPoint, evaluator: F) -> SensitivityAnalysis
where
    F: Fn(&ParameterPoint) -> f64 + Copy,
{
    let baseline_score = evaluator(baseline);
    let importance = ablation_study(baseline, baseline_score, evaluator);

    let mut ordered: Vec<(&String, &f64)> = importance.iter().collect();
    ordered.sort_by(|a, b| {
        b.1.partial_cmp(a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            // HashMap iteration order is arbitrary; without a tie-breaker,
            // equally important parameters would get different ranks on
            // different runs.
            .then_with(|| a.0.cmp(b.0))
    });
    let importance_rank: HashMap<String, usize> = ordered
        .into_iter()
        .enumerate()
        .map(|(position, (name, _))| (name.clone(), position + 1))
        .collect();

    let oat_curves = oat_sensitivity(baseline, evaluator, 7);
    let interactions = detect_interactions(baseline, baseline_score, evaluator, &importance);

    SensitivityAnalysis {
        importance,
        importance_rank,
        oat_curves,
        interactions,
    }
}
/// Prints a human-readable summary of `analysis` to standard output.
///
/// The first section lists up to ten parameters by descending importance,
/// each with its stored rank (`0` when no rank is recorded) and a bar of at
/// most 50 cells scaled as `importance * 50`. The second section, printed
/// only when interactions exist, lists up to five interactions by descending
/// absolute strength and labels positive values "synergistic" and all others
/// "antagonistic".
///
/// Ties are broken deterministically (importance: by stored rank, then name;
/// interactions: by key) so repeated calls on the same analysis print the
/// same lines in the same order.
pub fn print_summary(analysis: &SensitivityAnalysis) {
    println!("\n=== Parameter Importance (Ablation) ===\n");

    let mut rows: Vec<(&str, f64, usize)> = analysis
        .importance
        .iter()
        .map(|(name, &score)| {
            let rank = analysis.importance_rank.get(name).copied().unwrap_or(0);
            (name.as_str(), score, rank)
        })
        .collect();
    rows.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            // HashMap iteration order is arbitrary: fall back to the stored
            // rank and then the name so equal scores print in a stable order
            // that agrees with the displayed ranks.
            .then_with(|| a.2.cmp(&b.2))
            .then_with(|| a.0.cmp(b.0))
    });

    for &(name, score, rank) in rows.iter().take(10) {
        // Float-to-usize casts saturate, so negative or NaN scores give an empty bar.
        let bar_len = (score * 50.0).round() as usize;
        let bar = "█".repeat(bar_len.min(50));
        println!("{:>30}: {:>6.4} [{}] {}", name, score, rank, bar);
    }

    if analysis.interactions.is_empty() {
        return;
    }

    println!("\n=== Significant Interactions ===\n");

    let mut pairs: Vec<(&str, f64)> = analysis
        .interactions
        .iter()
        .map(|(key, &strength)| (key.as_str(), strength))
        .collect();
    pairs.sort_by(|a, b| {
        b.1.abs()
            .partial_cmp(&a.1.abs())
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(b.0))
    });

    for &(key, strength) in pairs.iter().take(5) {
        // A well-formed key has exactly one '|' separator; anything else is
        // shown whole with "?" as the second operand.
        let mut parts = key.split('|');
        let (a, b) = match (parts.next(), parts.next(), parts.next()) {
            (Some(first), Some(second), None) => (first, second),
            _ => (key, "?"),
        };
        let direction = if strength > 0.0 {
            "synergistic"
        } else {
            "antagonistic"
        };
        println!("  {} × {} = {:+.4} ({})", a, b, strength, direction);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point fields.
    const EPS: f64 = 1e-6;

    /// Ablating one boost resets it to 1.0 and leaves other fields alone.
    #[test]
    fn test_ablate_boost() {
        let reference = ParameterPoint::default();
        let result = ablate_param(&reference, "boost_mentioned_ident");

        let boost_delta = result.boost_mentioned_ident - 1.0;
        let alpha_delta = result.pagerank_alpha - reference.pagerank_alpha;
        assert!(boost_delta.abs() < EPS);
        assert!(alpha_delta.abs() < EPS);
    }

    /// `set_param` writes the requested value into the named field.
    #[test]
    fn test_set_param() {
        let updated = set_param(&ParameterPoint::default(), "pagerank_alpha", 0.75);
        assert!((updated.pagerank_alpha - 0.75).abs() < EPS);
    }

    /// A parameter the score depends on gets a strictly positive importance.
    #[test]
    fn test_ablation_study() {
        fn score(p: &ParameterPoint) -> f64 {
            p.boost_mentioned_ident * 0.1 + p.boost_chat_file * 0.05
        }

        let reference = ParameterPoint::default();
        let reference_score = score(&reference);
        let importance = ablation_study(&reference, reference_score, score);
        assert!(importance["boost_mentioned_ident"] > 0.0);
    }

    /// The curve for `pagerank_alpha` has the requested length and, with the
    /// score equal to the parameter itself, never decreases along the sweep.
    #[test]
    fn test_oat_sensitivity() {
        let reference = ParameterPoint::default();
        let curves = oat_sensitivity(&reference, |p: &ParameterPoint| p.pagerank_alpha, 5);

        assert!(curves.contains_key("pagerank_alpha"));
        let alpha_curve = &curves["pagerank_alpha"];
        assert_eq!(alpha_curve.len(), 5);
        for pair in alpha_curve.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }
    }
}