use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single scenario: its probability, an optional description, and the
/// scalar overrides that apply when the scenario is selected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioDefinition {
/// Probability assigned to this scenario. This is the only field without a
/// serde default, so it must be present in serialized input.
/// `ScenarioConfig::validate` checks its range and the sum across scenarios.
pub probability: f64,
/// Free-form description; falls back to an empty string when the field is
/// missing from serialized input.
#[serde(default)]
pub description: String,
/// Overrides keyed by scalar name; falls back to an empty map when the
/// field is missing from serialized input.
#[serde(default)]
pub scalars: HashMap<String, ScalarOverride>,
}
/// The replacement for a scalar inside a scenario: either a literal number or
/// a formula string.
///
/// The enum is `#[serde(untagged)]`, so no variant tag is written or expected.
/// During deserialization serde tries the variants in declaration order: a
/// bare number becomes `Value`, and a map with a `formula` key becomes
/// `Formula`. Reordering the variants can therefore change how input is parsed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ScalarOverride {
/// A fixed numeric override.
Value(f64),
/// An override given as formula text. The text is stored verbatim; nothing
/// in this type parses or evaluates it.
Formula { formula: String },
}
impl ScalarOverride {
    /// Returns the number held by a `Value` override, or `None` when this
    /// override is a `Formula`.
    #[must_use]
    pub const fn as_value(&self) -> Option<f64> {
        if let Self::Value(number) = self {
            Some(*number)
        } else {
            None
        }
    }

    /// Returns the formula text held by a `Formula` override, or `None` when
    /// this override is a `Value`.
    #[must_use]
    pub fn as_formula(&self) -> Option<&str> {
        if let Self::Formula { formula } = self {
            Some(formula.as_str())
        } else {
            None
        }
    }
}
/// A collection of scenarios keyed by scenario name.
///
/// `Default` yields a configuration with no scenarios.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScenarioConfig {
/// Scenario definitions by name; falls back to an empty map when the field
/// is missing from serialized input.
#[serde(default)]
pub scenarios: HashMap<String, ScenarioDefinition>,
}
impl ScenarioConfig {
    /// Creates a configuration with no scenarios.
    #[must_use]
    pub fn new() -> Self {
        Self {
            scenarios: HashMap::new(),
        }
    }

    /// Registers `scenario` under `name` and returns `self` so calls can be
    /// chained.
    ///
    /// If a scenario with the same name already exists it is replaced.
    pub fn add_scenario(&mut self, name: &str, scenario: ScenarioDefinition) -> &mut Self {
        self.scenarios.insert(name.to_string(), scenario);
        self
    }

    /// Checks that the configuration describes a usable probability
    /// distribution.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when, in this order:
    /// - no scenarios are defined;
    /// - the probabilities do not sum to 1.0 within a tolerance of 0.001;
    /// - any scenario's probability is NaN or lies outside `0.0..=1.0`. When
    ///   several scenarios are invalid, the one whose name sorts first is
    ///   reported, so the message does not depend on hash-map ordering.
    pub fn validate(&self) -> Result<(), String> {
        const TOLERANCE: f64 = 0.001;

        if self.scenarios.is_empty() {
            return Err("No scenarios defined".to_string());
        }

        let total_prob: f64 = self.scenarios.values().map(|s| s.probability).sum();
        if (total_prob - 1.0).abs() > TOLERANCE {
            return Err(format!(
                "Scenario probabilities must sum to 1.0, got {total_prob:.4}"
            ));
        }

        // `contains` is false for NaN, so the negation rejects NaN as well as
        // values outside the interval. A NaN makes the total NaN too, which
        // slips past the `>` comparison above; this check is what catches it.
        let offender = self
            .scenarios
            .iter()
            .filter(|(_, scenario)| !(0.0..=1.0).contains(&scenario.probability))
            .min_by(|(left, _), (right, _)| left.cmp(right));
        if let Some((name, scenario)) = offender {
            return Err(format!(
                "Scenario '{}' probability must be between 0 and 1, got {}",
                name, scenario.probability
            ));
        }

        Ok(())
    }

    /// Returns the names of all scenarios, sorted lexicographically so the
    /// result is the same on every call and every run.
    #[must_use]
    pub fn scenario_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.scenarios.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Returns `true` when a scenario called `name` is registered.
    #[must_use]
    pub fn has_scenario(&self, name: &str) -> bool {
        self.scenarios.contains_key(name)
    }

    /// Looks up the scenario called `name`, returning `None` when it is not
    /// registered.
    #[must_use]
    pub fn get_scenario(&self, name: &str) -> Option<&ScenarioDefinition> {
        self.scenarios.get(name)
    }
}
impl ScenarioDefinition {
    /// Creates a scenario with the given probability, an empty description,
    /// and no scalar overrides.
    #[must_use]
    pub fn new(probability: f64) -> Self {
        let description = String::new();
        let scalars = HashMap::new();
        Self {
            probability,
            description,
            scalars,
        }
    }

    /// Builder step: sets the description, replacing any previous one.
    #[must_use]
    pub fn with_description(mut self, description: &str) -> Self {
        self.description = String::from(description);
        self
    }

    /// Builder step: overrides the scalar `name` with a fixed number. An
    /// existing override under the same name is replaced.
    #[must_use]
    pub fn with_scalar(mut self, name: &str, value: f64) -> Self {
        let key = name.to_owned();
        self.scalars.insert(key, ScalarOverride::Value(value));
        self
    }

    /// Builder step: overrides the scalar `name` with formula text. An
    /// existing override under the same name is replaced.
    #[must_use]
    pub fn with_formula(mut self, name: &str, formula: &str) -> Self {
        let expression = ScalarOverride::Formula {
            formula: formula.to_owned(),
        };
        self.scalars.insert(name.to_owned(), expression);
        self
    }
}
#[cfg(test)]
mod config_tests {
    use super::*;

    /// An empty configuration fails validation; three scenarios whose
    /// probabilities total 1.0 pass.
    #[test]
    fn test_scenario_config_validation() {
        let mut config = ScenarioConfig::new();
        assert!(config.validate().is_err());

        let cases = [
            ("base", 0.5, "Base case", 0.05),
            ("bull", 0.3, "Bull case", 0.15),
            ("bear", 0.2, "Bear case", -0.10),
        ];
        for (name, probability, description, growth) in cases {
            let scenario = ScenarioDefinition::new(probability)
                .with_description(description)
                .with_scalar("revenue_growth", growth);
            config.add_scenario(name, scenario);
        }

        assert!(config.validate().is_ok());
    }

    /// Probabilities totalling 0.8 are rejected with the sum-specific message.
    #[test]
    fn test_probabilities_must_sum_to_one() {
        let mut config = ScenarioConfig::new();
        config
            .add_scenario("a", ScenarioDefinition::new(0.5))
            .add_scenario("b", ScenarioDefinition::new(0.3));

        assert!(matches!(
            config.validate(),
            Err(message) if message.contains("sum to 1.0")
        ));
    }

    /// Numeric and formula overrides are each readable through their accessor.
    #[test]
    fn test_scalar_override_types() {
        let scenario = ScenarioDefinition::new(0.5)
            .with_scalar("fixed", 100.0)
            .with_formula("distribution", "=MC.Normal(1000, 100)");

        let fixed = scenario.scalars.get("fixed").unwrap();
        assert_eq!(fixed.as_value(), Some(100.0));

        let distribution = scenario.scalars.get("distribution").unwrap();
        assert_eq!(distribution.as_formula(), Some("=MC.Normal(1000, 100)"));
    }
}