use std::collections::HashMap;
/// Result of a Spearman rank-correlation sensitivity analysis between sampled
/// input variables and sampled output variables.
#[derive(Debug, Clone)]
pub struct SensitivityAnalysis {
    /// Spearman correlation for each input/output pair; when built by
    /// `SensitivityAnalysis::compute` it is indexed as
    /// `correlations[input_name][output_name]`.
    pub correlations: HashMap<String, HashMap<String, f64>>,
    /// Tornado-chart bars (one per input) keyed by output name; `compute`
    /// stores each list sorted by descending `impact`.
    pub tornado_data: HashMap<String, Vec<TornadoBar>>,
}
/// One bar of a tornado chart: how strongly a single input variable is
/// rank-correlated with a particular output.
#[derive(Debug, Clone)]
pub struct TornadoBar {
    /// Name of the input variable.
    pub variable: String,
    /// Signed Spearman rank correlation (rho) between the input and the output.
    pub correlation: f64,
    /// Absolute value of `correlation`; bars are ranked by this (largest first).
    pub impact: f64,
}
impl SensitivityAnalysis {
    /// Computes the Spearman rank correlation of every input against every output.
    ///
    /// Each map associates a variable name with its sampled values; sample `k`
    /// of an input is paired with sample `k` of an output (presumably both come
    /// from the same simulation run — confirm with callers). A pair whose
    /// sample vectors are empty or of different lengths gets a correlation of
    /// `0.0`, as returned by [`spearman_correlation`].
    ///
    /// In the result, `correlations[input][output]` holds rho, and
    /// `tornado_data[output]` holds one [`TornadoBar`] per input, sorted by
    /// descending `impact` (|rho|). Bars with equal impact are ordered by
    /// ascending variable name so the ordering does not depend on `HashMap`
    /// iteration order.
    #[must_use]
    pub fn compute(
        input_samples: &HashMap<String, Vec<f64>>,
        output_samples: &HashMap<String, Vec<f64>>,
    ) -> Self {
        let mut correlations: HashMap<String, HashMap<String, f64>> = HashMap::new();
        let mut tornado_data: HashMap<String, Vec<TornadoBar>> =
            HashMap::with_capacity(output_samples.len());
        for (output_name, output_values) in output_samples {
            let mut bars: Vec<TornadoBar> = input_samples
                .iter()
                .map(|(input_name, input_values)| {
                    let rho = spearman_correlation(input_values, output_values);
                    TornadoBar {
                        variable: input_name.clone(),
                        correlation: rho,
                        impact: rho.abs(),
                    }
                })
                .collect();
            for bar in &bars {
                correlations
                    .entry(bar.variable.clone())
                    .or_default()
                    .insert(output_name.clone(), bar.correlation);
            }
            // Largest impact first. `total_cmp` is a total order, and the name
            // tie-break makes equal-impact bars come out in a stable order.
            bars.sort_by(|a, b| {
                b.impact
                    .total_cmp(&a.impact)
                    .then_with(|| a.variable.cmp(&b.variable))
            });
            tornado_data.insert(output_name.clone(), bars);
        }
        Self {
            correlations,
            tornado_data,
        }
    }

    /// Returns up to `n` tornado bars for `output`, highest impact first.
    ///
    /// Returns an empty `Vec` when `output` has no tornado data.
    #[must_use]
    pub fn top_drivers(&self, output: &str, n: usize) -> Vec<&TornadoBar> {
        self.tornado_data
            .get(output)
            .map(|bars| bars.iter().take(n).collect())
            .unwrap_or_default()
    }

    /// Looks up the correlation between `input` and `output`, or `None` if
    /// either name is unknown.
    #[must_use]
    pub fn get_correlation(&self, input: &str, output: &str) -> Option<f64> {
        self.correlations.get(input)?.get(output).copied()
    }

    /// Serializes the tornado bars for `output` as a compact JSON object:
    /// `{"output":"…","sensitivity":[{"variable":"…","correlation":…,"impact":…},…]}`.
    ///
    /// Names are JSON-escaped and numbers are written with 4 decimal places;
    /// non-finite numbers (which JSON cannot represent) are written as `null`.
    /// Returns `None` if `output` has no tornado data.
    #[must_use]
    pub fn to_tornado_json(&self, output: &str) -> Option<String> {
        let bars = self.tornado_data.get(output)?;
        let json_bars: Vec<String> = bars
            .iter()
            .map(|bar| {
                format!(
                    r#"{{"variable":"{}","correlation":{},"impact":{}}}"#,
                    escape_json_string(&bar.variable),
                    format_json_number(bar.correlation),
                    format_json_number(bar.impact)
                )
            })
            .collect();
        Some(format!(
            r#"{{"output":"{}","sensitivity":[{}]}}"#,
            escape_json_string(output),
            json_bars.join(",")
        ))
    }
}

/// Escapes `s` so it can be placed between double quotes in a JSON document
/// (quote, backslash and control characters, per RFC 8259 section 7).
fn escape_json_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if u32::from(c) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
            }
            c => escaped.push(c),
        }
    }
    escaped
}

/// Formats `value` as a JSON number with 4 decimal places, or `null` when it
/// is NaN or infinite (JSON has no literal for those).
fn format_json_number(value: f64) -> String {
    if value.is_finite() {
        format!("{value:.4}")
    } else {
        String::from("null")
    }
}
/// Spearman rank correlation coefficient between the samples `x` and `y`.
///
/// Each sample is replaced by its average-tie ranks and the Pearson
/// correlation of those ranks is returned. Yields `0.0` when the slices are
/// empty or of different lengths, and also when either sample is constant
/// (zero rank variance).
#[must_use]
pub fn spearman_correlation(x: &[f64], y: &[f64]) -> f64 {
    if x.is_empty() || x.len() != y.len() {
        return 0.0;
    }
    pearson_correlation(&compute_ranks(x), &compute_ranks(y))
}
/// Assigns 1-based ranks to `values`, giving tied values the average of the
/// ranks they span (the fractional ranking used by Spearman's rho).
///
/// A value ties with the first value of the current run when the two are
/// exactly equal or differ by less than `1e-10`. Values are ordered with
/// [`f64::total_cmp`], so NaN and infinite inputs get a deterministic rank
/// instead of stalling the tie-grouping loop.
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    // `total_cmp` is a total order even with NaN, so the sort is well-defined.
    indexed.sort_by(|a, b| a.1.total_cmp(&b.1));
    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        let anchor = indexed[i].1;
        // Exact equality groups repeated infinities; the tolerance groups
        // nearly-equal finite values.
        let is_tie = |v: f64| v == anchor || (v - anchor).abs() < 1e-10;
        // Start past `i` so every iteration consumes at least one element, even
        // when `anchor - anchor` is NaN (NaN or infinite anchors).
        let mut j = i + 1;
        while j < n && is_tie(indexed[j].1) {
            j += 1;
        }
        // Sorted positions i..j carry ranks i+1..=j; their mean is (i + 1 + j) / 2.
        let avg_rank = (i + j + 1) as f64 / 2.0;
        for &(original_index, _) in &indexed[i..j] {
            ranks[original_index] = avg_rank;
        }
        i = j;
    }
    ranks
}
/// Pearson product-moment correlation of two samples.
///
/// Both means are taken over `x.len()` samples, so the slices are expected to
/// have equal length (the caller `spearman_correlation` checks this). Returns
/// `0.0` when the square root of the product of the two sums of squared
/// deviations is below `1e-10`, i.e. when either sample is (near-)constant.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let count = x.len() as f64;
    let x_bar = x.iter().sum::<f64>() / count;
    let y_bar = y.iter().sum::<f64>() / count;
    // Accumulate the cross-product and both sums of squares in one pass.
    let (cross, sum_sq_x, sum_sq_y) =
        x.iter()
            .zip(y)
            .fold((0.0, 0.0, 0.0), |(cross, sxx, syy), (&xi, &yi)| {
                let dx = xi - x_bar;
                let dy = yi - y_bar;
                (cross + dx * dy, sxx + dx * dx, syy + dy * dy)
            });
    let scale = (sum_sq_x * sum_sq_y).sqrt();
    if scale < 1e-10 {
        0.0
    } else {
        cross / scale
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spearman_perfect_positive() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let rho = spearman_correlation(&x, &y);
        assert!((rho - 1.0).abs() < 1e-10, "Expected 1.0, got {rho}");
    }

    #[test]
    fn test_spearman_perfect_negative() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![50.0, 40.0, 30.0, 20.0, 10.0];
        let rho = spearman_correlation(&x, &y);
        assert!((rho + 1.0).abs() < 1e-10, "Expected -1.0, got {rho}");
    }

    #[test]
    fn test_spearman_moderate_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![3.0, 1.0, 4.0, 2.0, 5.0];
        let rho = spearman_correlation(&x, &y);
        assert!((rho - 0.5).abs() < 0.1, "Expected ~0.5, got {rho}");
    }

    #[test]
    fn test_spearman_with_ties() {
        // The two 2.0 values in `x` share the average rank 2.5.
        let x = vec![1.0, 2.0, 2.0, 4.0, 5.0];
        let y = vec![10.0, 20.0, 25.0, 40.0, 50.0];
        let rho = spearman_correlation(&x, &y);
        assert!(rho > 0.9, "Expected high positive correlation, got {rho}");
    }

    #[test]
    fn test_spearman_degenerate_inputs() {
        assert_eq!(spearman_correlation(&[], &[]), 0.0);
        assert_eq!(spearman_correlation(&[1.0, 2.0], &[1.0, 2.0, 3.0]), 0.0);
        // A constant sample has zero rank variance.
        assert_eq!(spearman_correlation(&[7.0, 7.0, 7.0], &[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn test_compute_ranks_simple() {
        let values = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let ranks = compute_ranks(&values);
        // The two 1.0 values occupy ranks 1 and 2, so each gets 1.5.
        let expected = [3.0, 1.5, 4.0, 1.5, 5.0];
        assert_eq!(ranks.len(), expected.len());
        for (got, want) in ranks.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "Expected {want}, got {got}");
        }
    }

    #[test]
    fn test_sensitivity_analysis() {
        let mut inputs: HashMap<String, Vec<f64>> = HashMap::new();
        inputs.insert(
            "revenue".to_string(),
            vec![100.0, 110.0, 120.0, 130.0, 140.0],
        );
        inputs.insert("costs".to_string(), vec![50.0, 55.0, 52.0, 58.0, 60.0]);
        let mut outputs: HashMap<String, Vec<f64>> = HashMap::new();
        outputs.insert("profit".to_string(), vec![50.0, 55.0, 68.0, 72.0, 80.0]);
        let analysis = SensitivityAnalysis::compute(&inputs, &outputs);
        let revenue_impact = analysis.get_correlation("revenue", "profit").unwrap();
        let costs_impact = analysis.get_correlation("costs", "profit").unwrap();
        assert!(
            revenue_impact.abs() > costs_impact.abs(),
            "Revenue impact {revenue_impact} should be greater than costs impact {costs_impact}"
        );
    }

    #[test]
    fn test_tornado_ordering() {
        let mut inputs: HashMap<String, Vec<f64>> = HashMap::new();
        // Spearman rho against `result`: 1.0, 0.9 and 0.5 respectively.
        inputs.insert("high_impact".to_string(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        inputs.insert("medium_impact".to_string(), vec![1.0, 2.5, 2.0, 4.0, 5.0]);
        inputs.insert("weak_impact".to_string(), vec![3.0, 1.0, 4.0, 2.0, 5.0]);
        let mut outputs: HashMap<String, Vec<f64>> = HashMap::new();
        outputs.insert("result".to_string(), vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let analysis = SensitivityAnalysis::compute(&inputs, &outputs);
        let top = analysis.top_drivers("result", 3);
        assert_eq!(top.len(), 3);
        assert_eq!(top[0].variable, "high_impact");
        assert!(
            top[0].correlation > 0.99,
            "Expected ~1.0, got {}",
            top[0].correlation
        );
        assert_eq!(top[1].variable, "medium_impact");
        assert!(
            top[1].correlation > 0.8 && top[1].correlation < 1.0,
            "Expected high but imperfect correlation, got {}",
            top[1].correlation
        );
        assert_eq!(top[2].variable, "weak_impact");
        assert!(
            (top[2].correlation - 0.5).abs() < 1e-10,
            "Expected 0.5, got {}",
            top[2].correlation
        );
    }

    #[test]
    fn test_lookups_for_unknown_names() {
        let mut inputs: HashMap<String, Vec<f64>> = HashMap::new();
        inputs.insert("a".to_string(), vec![1.0, 2.0, 3.0]);
        let mut outputs: HashMap<String, Vec<f64>> = HashMap::new();
        outputs.insert("out".to_string(), vec![3.0, 1.0, 2.0]);
        let analysis = SensitivityAnalysis::compute(&inputs, &outputs);
        assert!(analysis.top_drivers("missing", 5).is_empty());
        // Asking for more bars than exist returns all of them.
        assert_eq!(analysis.top_drivers("out", 5).len(), 1);
        assert!(analysis.get_correlation("a", "missing").is_none());
        assert!(analysis.get_correlation("missing", "out").is_none());
        assert!(analysis.to_tornado_json("missing").is_none());
    }

    #[test]
    fn test_tornado_json() {
        let mut inputs: HashMap<String, Vec<f64>> = HashMap::new();
        inputs.insert("a".to_string(), vec![1.0, 2.0, 3.0]);
        inputs.insert("b".to_string(), vec![3.0, 2.0, 1.0]);
        let mut outputs: HashMap<String, Vec<f64>> = HashMap::new();
        outputs.insert("out".to_string(), vec![10.0, 20.0, 30.0]);
        let analysis = SensitivityAnalysis::compute(&inputs, &outputs);
        let json = analysis.to_tornado_json("out").unwrap();
        assert!(json.contains("\"output\":\"out\""));
        assert!(json.contains("\"sensitivity\":"));
        assert!(json.contains("\"variable\":"));
        // Numbers are emitted with four decimal places.
        assert!(json.contains("\"correlation\":1.0000"));
        assert!(json.contains("\"correlation\":-1.0000"));
        assert!(json.contains("\"impact\":1.0000"));
    }
}