use std::collections::HashMap;
use std::str::FromStr;
use std::time::Instant;
use super::config::MonteCarloConfig;
use super::distributions::{parse_distribution, Distribution};
use super::sampler::{Sampler, SamplingMethod};
use super::statistics::{evaluate_threshold, parse_threshold, Histogram, Statistics};
use crate::types::ParsedModel;
/// Everything produced by one simulation run of [`MonteCarloEngine`].
#[derive(Debug, Clone)]
pub struct SimulationResult {
/// Clone of the engine configuration the run was executed with.
pub config: MonteCarloConfig,
/// Number of iterations reported for the run; filled from `config.iterations`.
pub iterations_completed: usize,
/// Time elapsed between the start of the run and result assembly, in milliseconds.
pub execution_time_ms: u64,
/// Per-output results, keyed by the output's variable name.
pub outputs: HashMap<String, OutputResult>,
/// Values drawn for each registered input distribution, keyed by variable name.
pub input_samples: HashMap<String, Vec<f64>>,
}
/// Summary of the simulated values of a single output variable.
#[derive(Debug, Clone)]
pub struct OutputResult {
/// Name of the output variable this result belongs to.
pub variable: String,
/// Statistics computed from `samples` via `Statistics::from_samples`.
pub statistics: Statistics,
/// The simulated values the statistics and histogram were computed from.
pub samples: Vec<f64>,
/// Histogram computed from `samples` via `Histogram::from_samples`.
pub histogram: Histogram,
/// Result of `evaluate_threshold` for each threshold expression that parsed
/// successfully, keyed by the expression exactly as written in the configuration.
pub threshold_probabilities: HashMap<String, f64>,
}
/// Runs Monte Carlo simulations over a set of named input distributions.
pub struct MonteCarloEngine {
// Configuration the engine was created with; validated in `new`.
config: MonteCarloConfig,
// Sampler built from the configured sampling method and seed; supplies the RNG.
sampler: Sampler,
// Input distributions keyed by the variable name they were registered under.
distributions: HashMap<String, Distribution>,
}
impl MonteCarloEngine {
pub fn new(config: MonteCarloConfig) -> Result<Self, String> {
config.validate()?;
let method = SamplingMethod::from_str(&config.sampling)?;
let sampler = Sampler::new(method, config.seed);
Ok(Self {
config,
sampler,
distributions: HashMap::new(),
})
}
pub fn add_distribution(&mut self, variable: &str, distribution: Distribution) {
self.distributions
.insert(variable.to_string(), distribution);
}
pub fn parse_distributions_from_model(&mut self, model: &ParsedModel) -> Result<(), String> {
for (name, scalar) in &model.scalars {
if let Some(formula) = &scalar.formula {
let formula = formula.trim();
let formula_content = formula.strip_prefix('=').unwrap_or(formula);
if formula_content.starts_with("MC.") {
let dist = parse_distribution(formula_content)?;
self.add_distribution(name, dist);
}
}
}
Ok(())
}
pub fn run(&mut self) -> Result<SimulationResult, String> {
let start = Instant::now();
let n = self.config.iterations;
let mut input_samples: HashMap<String, Vec<f64>> = HashMap::new();
for (var_name, dist) in &self.distributions {
let samples = dist.sample_n(self.sampler.rng_mut(), n);
input_samples.insert(var_name.clone(), samples);
}
let mut outputs = HashMap::new();
for output_config in &self.config.outputs {
let var = &output_config.variable;
let samples = input_samples
.get(var)
.or_else(|| input_samples.get(&format!("scalars.{var}")))
.cloned()
.unwrap_or_else(|| vec![0.0; n]);
let statistics = Statistics::from_samples(&samples);
let histogram = Histogram::from_samples(&samples, 50);
let mut threshold_probabilities = HashMap::new();
if let Some(threshold_str) = &output_config.threshold {
if let Ok((op, value)) = parse_threshold(threshold_str) {
let prob = evaluate_threshold(&samples, &op, value);
threshold_probabilities.insert(threshold_str.clone(), prob);
}
}
outputs.insert(
var.clone(),
OutputResult {
variable: var.clone(),
statistics,
samples,
histogram,
threshold_probabilities,
},
);
}
#[allow(clippy::cast_possible_truncation)]
let execution_time_ms = start.elapsed().as_millis() as u64;
Ok(SimulationResult {
config: self.config.clone(),
iterations_completed: n,
execution_time_ms,
outputs,
input_samples,
})
}
pub fn run_with_evaluator<F>(&mut self, mut evaluator: F) -> Result<SimulationResult, String>
where
F: FnMut(&HashMap<String, f64>) -> HashMap<String, f64>,
{
let start = Instant::now();
let n = self.config.iterations;
let mut input_samples: HashMap<String, Vec<f64>> = HashMap::new();
for (var_name, dist) in &self.distributions {
let samples = dist.sample_n(self.sampler.rng_mut(), n);
input_samples.insert(var_name.clone(), samples);
}
let output_vars: Vec<String> = self
.config
.outputs
.iter()
.map(|o| o.variable.clone())
.collect();
let mut output_samples: HashMap<String, Vec<f64>> = output_vars
.iter()
.map(|v| (v.clone(), Vec::with_capacity(n)))
.collect();
for i in 0..n {
let mut inputs: HashMap<String, f64> = HashMap::new();
for (var, samples) in &input_samples {
inputs.insert(var.clone(), samples[i]);
}
let outputs = evaluator(&inputs);
for var in &output_vars {
let value = outputs.get(var).copied().unwrap_or(0.0);
output_samples.get_mut(var).unwrap().push(value);
}
}
let mut outputs = HashMap::new();
for output_config in &self.config.outputs {
let var = &output_config.variable;
let samples = output_samples.get(var).cloned().unwrap_or_default();
let statistics = Statistics::from_samples(&samples);
let histogram = Histogram::from_samples(&samples, 50);
let mut threshold_probabilities = HashMap::new();
if let Some(threshold_str) = &output_config.threshold {
if let Ok((op, value)) = parse_threshold(threshold_str) {
let prob = evaluate_threshold(&samples, &op, value);
threshold_probabilities.insert(threshold_str.clone(), prob);
}
}
outputs.insert(
var.clone(),
OutputResult {
variable: var.clone(),
statistics,
samples,
histogram,
threshold_probabilities,
},
);
}
#[allow(clippy::cast_possible_truncation)]
let execution_time_ms = start.elapsed().as_millis() as u64;
Ok(SimulationResult {
config: self.config.clone(),
iterations_completed: n,
execution_time_ms,
outputs,
input_samples,
})
}
#[must_use]
pub const fn sampler(&self) -> &Sampler {
&self.sampler
}
pub const fn sampler_mut(&mut self) -> &mut Sampler {
&mut self.sampler
}
}
impl SimulationResult {
#[must_use]
pub fn to_yaml(&self) -> String {
use std::fmt::Write;
let mut output = String::new();
output.push_str("monte_carlo_results:\n");
let _ = writeln!(output, " iterations: {}", self.iterations_completed);
let _ = writeln!(output, " execution_time_ms: {}", self.execution_time_ms);
let _ = writeln!(output, " sampling: {}", self.config.sampling);
if let Some(seed) = self.config.seed {
let _ = writeln!(output, " seed: {seed}");
}
output.push_str("\n outputs:\n");
for (var, result) in &self.outputs {
let _ = writeln!(output, " {var}:");
let _ = writeln!(output, " mean: {:.4}", result.statistics.mean);
let _ = writeln!(output, " median: {:.4}", result.statistics.median);
let _ = writeln!(output, " std_dev: {:.4}", result.statistics.std_dev);
let _ = writeln!(output, " min: {:.4}", result.statistics.min);
let _ = writeln!(output, " max: {:.4}", result.statistics.max);
output.push_str(" percentiles:\n");
for (p, v) in &result.statistics.percentiles {
let _ = writeln!(output, " p{p}: {v:.4}");
}
if !result.threshold_probabilities.is_empty() {
output.push_str(" thresholds:\n");
for (t, prob) in &result.threshold_probabilities {
let _ = writeln!(output, " \"{t}\": {prob:.4}");
}
}
}
output
}
pub fn to_json(&self) -> Result<String, serde_json::Error> {
use serde_json::{json, to_string_pretty};
let mut outputs_json = serde_json::Map::new();
for (var, result) in &self.outputs {
let percentiles: serde_json::Map<String, serde_json::Value> = result
.statistics
.percentiles
.iter()
.map(|(p, v)| (format!("p{p}"), json!(v)))
.collect();
let thresholds: serde_json::Map<String, serde_json::Value> = result
.threshold_probabilities
.iter()
.map(|(t, p)| (t.clone(), json!(p)))
.collect();
outputs_json.insert(
var.clone(),
json!({
"mean": result.statistics.mean,
"median": result.statistics.median,
"std_dev": result.statistics.std_dev,
"min": result.statistics.min,
"max": result.statistics.max,
"percentiles": percentiles,
"thresholds": thresholds,
}),
);
}
let result_json = json!({
"monte_carlo_results": {
"iterations": self.iterations_completed,
"execution_time_ms": self.execution_time_ms,
"sampling": self.config.sampling,
"seed": self.config.seed,
"outputs": outputs_json,
}
});
to_string_pretty(&result_json)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::monte_carlo::config::OutputConfig;

    /// Seeded configuration with a single `revenue` output and a `> 100000` threshold.
    fn test_config() -> MonteCarloConfig {
        let revenue_output = OutputConfig {
            variable: "revenue".to_string(),
            percentiles: vec![10, 50, 90],
            threshold: Some("> 100000".to_string()),
            label: None,
        };
        MonteCarloConfig {
            enabled: true,
            iterations: 10000,
            sampling: "latin_hypercube".to_string(),
            seed: Some(12345),
            outputs: vec![revenue_output],
            correlations: vec![],
        }
    }

    /// Engine built from [`test_config`] with `revenue` bound to `distribution`.
    fn revenue_engine(distribution: Distribution) -> MonteCarloEngine {
        let mut engine = MonteCarloEngine::new(test_config()).unwrap();
        engine.add_distribution("revenue", distribution);
        engine
    }

    /// Runs the simulation used by most tests: `revenue` drawn from
    /// `Distribution::normal(100_000.0, 15_000.0)`.
    fn run_revenue_simulation() -> SimulationResult {
        revenue_engine(Distribution::normal(100_000.0, 15_000.0).unwrap())
            .run()
            .unwrap()
    }

    #[test]
    fn test_engine_creation() {
        assert!(MonteCarloEngine::new(test_config()).is_ok());
    }

    #[test]
    fn test_add_distribution() {
        let engine = revenue_engine(Distribution::normal(100_000.0, 15_000.0).unwrap());
        assert!(engine.distributions.contains_key("revenue"));
    }

    #[test]
    fn test_run_simulation() {
        let result = run_revenue_simulation();

        assert_eq!(result.iterations_completed, 10000);
        assert!(result.input_samples.contains_key("revenue"));
        assert!(result.outputs.contains_key("revenue"));

        let stats = &result.outputs["revenue"].statistics;
        assert!((stats.mean - 100_000.0).abs() < 2_000.0);
        assert!(stats.percentiles.contains_key(&50));
    }

    #[test]
    fn test_run_with_evaluator() {
        let profit_output = OutputConfig {
            variable: "profit".to_string(),
            percentiles: vec![10, 50, 90],
            threshold: Some("> 0".to_string()),
            label: None,
        };
        let config = MonteCarloConfig {
            enabled: true,
            iterations: 1000,
            sampling: "latin_hypercube".to_string(),
            seed: Some(42),
            outputs: vec![profit_output],
            correlations: vec![],
        };

        let mut engine = MonteCarloEngine::new(config).unwrap();
        engine.add_distribution("revenue", Distribution::normal(100.0, 10.0).unwrap());
        engine.add_distribution("costs", Distribution::normal(80.0, 5.0).unwrap());

        // profit = revenue - costs, with missing inputs treated as zero.
        let result = engine
            .run_with_evaluator(|inputs| {
                let value_of = |key: &str| inputs.get(key).copied().unwrap_or(0.0);
                HashMap::from([(
                    "profit".to_string(),
                    value_of("revenue") - value_of("costs"),
                )])
            })
            .unwrap();

        let profit = &result.outputs["profit"];
        assert!((profit.statistics.mean - 20.0).abs() < 3.0);
        assert!(profit.threshold_probabilities["> 0"] > 0.9);
    }

    #[test]
    fn test_output_yaml() {
        let yaml = run_revenue_simulation().to_yaml();

        for expected in [
            "monte_carlo_results:",
            "iterations: 10000",
            "mean:",
            "percentiles:",
        ] {
            assert!(yaml.contains(expected));
        }
    }

    #[test]
    fn test_output_json() {
        let json = run_revenue_simulation().to_json().unwrap();

        for expected in [
            "\"monte_carlo_results\"",
            "\"iterations\": 10000",
            "\"mean\"",
        ] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn test_seed_reproducibility() {
        // Two independent engines sharing the same seeded configuration.
        let first = revenue_engine(Distribution::normal(100.0, 10.0).unwrap())
            .run()
            .unwrap();
        let second = revenue_engine(Distribution::normal(100.0, 10.0).unwrap())
            .run()
            .unwrap();

        assert_eq!(
            first.input_samples["revenue"],
            second.input_samples["revenue"]
        );
    }
}