use crate::compat::Instant;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use crate::error::CorpFinanceError;
use crate::types::*;
use crate::CorpFinanceResult;
/// Request describing a two-variable sensitivity sweep.
///
/// Each of the two variables carries a name plus `min`/`max`/`step` bounds that
/// are expanded into one axis of the resulting grid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityInput {
/// Arbitrary JSON payload supplied by the caller. None of the functions in
/// this file read it. NOTE(review): presumably the base-case model inputs
/// for whoever evaluates the grid — confirm against callers.
pub base_inputs: serde_json::Value,
/// Sweep definition for the first variable (matrix rows).
pub variable_1: SensitivityVariable,
/// Sweep definition for the second variable (matrix columns).
pub variable_2: SensitivityVariable,
/// Label of the metric being measured; copied verbatim into the output and
/// the result metadata.
pub output_metric: String,
/// Identifier of the computation to run. In this file it is only echoed
/// into the result metadata, never dispatched on.
pub compute_fn: String,
}
/// Result of a two-variable sensitivity sweep.
///
/// `matrix[i][j]` corresponds to `variable_1_values[i]` and
/// `variable_2_values[j]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityOutput {
/// Name of the first swept variable (row axis).
pub variable_1_name: String,
/// Name of the second swept variable (column axis).
pub variable_2_name: String,
/// Grid points generated for the first variable, in ascending order.
pub variable_1_values: Vec<Decimal>,
/// Grid points generated for the second variable, in ascending order.
pub variable_2_values: Vec<Decimal>,
/// Label of the metric the matrix cells represent, copied from the input.
pub output_metric: String,
/// One row per `variable_1` value and one column per `variable_2` value.
/// `build_sensitivity_grid` fills it with zeros; `evaluate_sensitivity`
/// fills it with evaluated results (zero where evaluation failed).
pub matrix: Vec<Vec<Decimal>>,
/// Matrix value at `base_case_position` (`Decimal::ZERO` when the grid was
/// built without evaluation).
pub base_case_value: Decimal,
/// `(row, column)` of the grid point closest to the midpoint of each
/// variable's `[min, max]` range.
pub base_case_position: (usize, usize),
}
/// Expands a sweep definition into the ascending list of grid points.
///
/// Starting at `var.min`, values are emitted every `var.step` while they do not
/// exceed `var.max`. If the last stepped value falls short of `var.max`, the
/// upper bound itself is appended so the range end is always represented.
///
/// # Errors
///
/// Returns [`CorpFinanceError::InvalidInput`] when `var.step` is zero or
/// negative, or when `var.min` is greater than `var.max`.
fn generate_sweep_values(var: &SensitivityVariable) -> CorpFinanceResult<Vec<Decimal>> {
    if var.step <= Decimal::ZERO {
        return Err(CorpFinanceError::InvalidInput {
            field: format!("variable:{}", var.name),
            reason: "Step must be positive".into(),
        });
    }
    if var.min > var.max {
        return Err(CorpFinanceError::InvalidInput {
            field: format!("variable:{}", var.name),
            reason: "Min must be <= max".into(),
        });
    }

    let mut values = Vec::new();
    let mut current = var.min;
    while current <= var.max {
        values.push(current);
        // `Decimal`'s `+` panics on overflow, so advance with `checked_add`.
        // Also stop if the sum fails to move forward (e.g. the step is lost to
        // rounding at `Decimal`'s precision limit), which would otherwise loop
        // forever. In both cases the next point cannot be represented, so the
        // stepped part of the sweep ends here.
        match current.checked_add(var.step) {
            Some(next) if next > current => current = next,
            _ => break,
        }
    }

    // `min <= max` guarantees the loop pushed at least `var.min`, so the list
    // is never empty here; only the upper bound may still be missing.
    if let Some(&last) = values.last() {
        if last < var.max {
            values.push(var.max);
        }
    }

    Ok(values)
}
/// Finds the position of the element in `values` nearest to `target`.
///
/// Nearness is the absolute difference from `target`. When several elements
/// are equally near, the earliest one is chosen. An empty slice yields `0`.
fn closest_index(values: &[Decimal], target: Decimal) -> usize {
    let mut best: Option<(usize, Decimal)> = None;

    for (idx, candidate) in values.iter().enumerate() {
        let distance = (*candidate - target).abs();
        // Strict comparison keeps the first of any equally-near candidates.
        let improves = match best {
            Some((_, best_distance)) => distance < best_distance,
            None => true,
        };
        if improves {
            best = Some((idx, distance));
        }
    }

    best.map_or(0, |(idx, _)| idx)
}
pub fn build_sensitivity_grid(
input: &SensitivityInput,
) -> CorpFinanceResult<ComputationOutput<SensitivityOutput>> {
let start = Instant::now();
let warnings: Vec<String> = Vec::new();
let v1_values = generate_sweep_values(&input.variable_1)?;
let v2_values = generate_sweep_values(&input.variable_2)?;
let matrix = vec![vec![Decimal::ZERO; v2_values.len()]; v1_values.len()];
let mid1 = (input.variable_1.min + input.variable_1.max) / dec!(2);
let mid2 = (input.variable_2.min + input.variable_2.max) / dec!(2);
let base_row = closest_index(&v1_values, mid1);
let base_col = closest_index(&v2_values, mid2);
let output = SensitivityOutput {
variable_1_name: input.variable_1.name.clone(),
variable_2_name: input.variable_2.name.clone(),
variable_1_values: v1_values,
variable_2_values: v2_values,
output_metric: input.output_metric.clone(),
matrix,
base_case_value: Decimal::ZERO,
base_case_position: (base_row, base_col),
};
let elapsed = start.elapsed().as_micros() as u64;
Ok(with_metadata(
"2-Way Sensitivity Analysis Grid",
&serde_json::json!({
"variable_1": input.variable_1.name,
"variable_2": input.variable_2.name,
"output_metric": input.output_metric,
"compute_fn": input.compute_fn,
}),
warnings,
elapsed,
output,
))
}
pub fn evaluate_sensitivity<F>(
input: &SensitivityInput,
eval_fn: F,
) -> CorpFinanceResult<ComputationOutput<SensitivityOutput>>
where
F: Fn(Decimal, Decimal) -> CorpFinanceResult<Decimal>,
{
let start = Instant::now();
let mut warnings: Vec<String> = Vec::new();
let v1_values = generate_sweep_values(&input.variable_1)?;
let v2_values = generate_sweep_values(&input.variable_2)?;
let mut matrix = Vec::with_capacity(v1_values.len());
for v1 in &v1_values {
let mut row = Vec::with_capacity(v2_values.len());
for v2 in &v2_values {
match eval_fn(*v1, *v2) {
Ok(val) => row.push(val),
Err(e) => {
warnings.push(format!("Evaluation failed at ({v1}, {v2}): {e}"));
row.push(Decimal::ZERO);
}
}
}
matrix.push(row);
}
let mid1 = (input.variable_1.min + input.variable_1.max) / dec!(2);
let mid2 = (input.variable_2.min + input.variable_2.max) / dec!(2);
let base_row = closest_index(&v1_values, mid1);
let base_col = closest_index(&v2_values, mid2);
let base_case_value = matrix[base_row][base_col];
let output = SensitivityOutput {
variable_1_name: input.variable_1.name.clone(),
variable_2_name: input.variable_2.name.clone(),
variable_1_values: v1_values,
variable_2_values: v2_values,
output_metric: input.output_metric.clone(),
matrix,
base_case_value,
base_case_position: (base_row, base_col),
};
let elapsed = start.elapsed().as_micros() as u64;
Ok(with_metadata(
"2-Way Sensitivity Analysis (Evaluated)",
&serde_json::json!({
"variable_1": input.variable_1.name,
"variable_2": input.variable_2.name,
"output_metric": input.output_metric,
"compute_fn": input.compute_fn,
}),
warnings,
elapsed,
output,
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Shorthand for constructing a sweep definition.
    fn sweep(name: &str, min: Decimal, max: Decimal, step: Decimal) -> SensitivityVariable {
        SensitivityVariable {
            name: name.into(),
            min,
            max,
            step,
        }
    }

    /// A 5x5 WACC / growth-rate grid used by several tests.
    fn sample_input() -> SensitivityInput {
        SensitivityInput {
            base_inputs: serde_json::json!({}),
            variable_1: sweep("WACC", dec!(0.08), dec!(0.12), dec!(0.01)),
            variable_2: sweep("Growth Rate", dec!(0.01), dec!(0.05), dec!(0.01)),
            output_metric: "Enterprise Value".into(),
            compute_fn: "dcf".into(),
        }
    }

    #[test]
    fn test_build_grid() {
        let computed = build_sensitivity_grid(&sample_input()).unwrap();
        let grid = &computed.result;

        assert_eq!(grid.variable_1_name, "WACC");
        assert_eq!(grid.variable_2_name, "Growth Rate");
        assert_eq!(grid.variable_1_values.len(), 5);
        assert_eq!(grid.variable_2_values.len(), 5);
        assert_eq!(grid.matrix.len(), 5);
        assert_eq!(grid.matrix[0].len(), 5);
    }

    #[test]
    fn test_evaluate_sensitivity() {
        // Simple perpetuity-style model: value = 1000 / (wacc - growth).
        let perpetuity = |wacc: Decimal, growth: Decimal| {
            let spread = wacc - growth;
            if spread.is_zero() {
                return Err(CorpFinanceError::DivisionByZero {
                    context: "test model".into(),
                });
            }
            Ok(dec!(1000) / spread)
        };

        let computed = evaluate_sensitivity(&sample_input(), perpetuity).unwrap();
        let cells = &computed.result.matrix;

        // Down the first column (rising WACC) the value must strictly fall.
        for rows in cells.windows(2) {
            assert!(rows[0][0] > rows[1][0]);
        }
        // Along the first row (rising growth) the value must strictly rise.
        for pair in cells[0].windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    #[test]
    fn test_sweep_values() {
        let var = sweep("test", dec!(1), dec!(5), dec!(1));
        let points = generate_sweep_values(&var).unwrap();
        assert_eq!(points, vec![dec!(1), dec!(2), dec!(3), dec!(4), dec!(5)]);
    }

    #[test]
    fn test_sweep_with_non_exact_step() {
        // 0, 0.3, 0.6, 0.9 and then the upper bound appended.
        let var = sweep("test", dec!(0), dec!(1), dec!(0.3));
        let points = generate_sweep_values(&var).unwrap();
        assert_eq!(points.len(), 5);
        assert_eq!(*points.last().unwrap(), dec!(1));
    }

    #[test]
    fn test_invalid_step() {
        let input = SensitivityInput {
            base_inputs: serde_json::json!({}),
            variable_1: sweep("bad", dec!(0), dec!(1), dec!(0)),
            variable_2: sweep("ok", dec!(0), dec!(1), dec!(0.5)),
            output_metric: "test".into(),
            compute_fn: "test".into(),
        };
        assert!(build_sensitivity_grid(&input).is_err());
    }

    #[test]
    fn test_base_case_position() {
        let computed = build_sensitivity_grid(&sample_input()).unwrap();
        assert_eq!(computed.result.base_case_position, (2, 2));
    }
}