1use std::collections::HashMap;
10use std::str::FromStr;
11use std::time::Instant;
12
13use super::config::MonteCarloConfig;
14use super::distributions::{parse_distribution, Distribution};
15use super::sampler::{Sampler, SamplingMethod};
16use super::statistics::{evaluate_threshold, parse_threshold, Histogram, Statistics};
17use crate::types::ParsedModel;
18
19#[derive(Debug, Clone)]
21pub struct SimulationResult {
22 pub config: MonteCarloConfig,
24 pub iterations_completed: usize,
26 pub execution_time_ms: u64,
28 pub outputs: HashMap<String, OutputResult>,
30 pub input_samples: HashMap<String, Vec<f64>>,
32}
33
34#[derive(Debug, Clone)]
36pub struct OutputResult {
37 pub variable: String,
39 pub statistics: Statistics,
41 pub samples: Vec<f64>,
43 pub histogram: Histogram,
45 pub threshold_probabilities: HashMap<String, f64>,
47}
48
/// Runs Monte Carlo simulations over a set of named input distributions.
///
/// Holds the configuration passed to `MonteCarloEngine::new`, the sampler
/// built from it, and the distributions registered via `add_distribution`
/// or `parse_distributions_from_model`.
pub struct MonteCarloEngine {
    /// Configuration supplied at construction (validated in `new`).
    config: MonteCarloConfig,
    /// Sampler built from `config.sampling` and `config.seed`; owns the RNG.
    sampler: Sampler,
    /// Input distributions keyed by variable name.
    distributions: HashMap<String, Distribution>,
}
55
56impl MonteCarloEngine {
57 pub fn new(config: MonteCarloConfig) -> Result<Self, String> {
63 config.validate()?;
64
65 let method = SamplingMethod::from_str(&config.sampling)?;
66 let sampler = Sampler::new(method, config.seed);
67
68 Ok(Self {
69 config,
70 sampler,
71 distributions: HashMap::new(),
72 })
73 }
74
75 pub fn add_distribution(&mut self, variable: &str, distribution: Distribution) {
77 self.distributions
78 .insert(variable.to_string(), distribution);
79 }
80
81 pub fn parse_distributions_from_model(&mut self, model: &ParsedModel) -> Result<(), String> {
87 for (name, scalar) in &model.scalars {
88 if let Some(formula) = &scalar.formula {
89 let formula = formula.trim();
90 let formula_content = formula.strip_prefix('=').unwrap_or(formula);
92
93 if formula_content.starts_with("MC.") {
94 let dist = parse_distribution(formula_content)?;
95 self.add_distribution(name, dist);
96 }
97 }
98 }
99 Ok(())
100 }
101
102 pub fn run(&mut self) -> Result<SimulationResult, String> {
108 let start = Instant::now();
109 let n = self.config.iterations;
110
111 let mut input_samples: HashMap<String, Vec<f64>> = HashMap::new();
113
114 for (var_name, dist) in &self.distributions {
115 let samples = dist.sample_n(self.sampler.rng_mut(), n);
116 input_samples.insert(var_name.clone(), samples);
117 }
118
119 let mut outputs = HashMap::new();
122
123 for output_config in &self.config.outputs {
124 let var = &output_config.variable;
125
126 let samples = input_samples
129 .get(var)
130 .or_else(|| input_samples.get(&format!("scalars.{var}")))
131 .cloned()
132 .unwrap_or_else(|| vec![0.0; n]);
133
134 let statistics = Statistics::from_samples(&samples);
136
137 let histogram = Histogram::from_samples(&samples, 50);
139
140 let mut threshold_probabilities = HashMap::new();
142 if let Some(threshold_str) = &output_config.threshold {
143 if let Ok((op, value)) = parse_threshold(threshold_str) {
144 let prob = evaluate_threshold(&samples, &op, value);
145 threshold_probabilities.insert(threshold_str.clone(), prob);
146 }
147 }
148
149 outputs.insert(
150 var.clone(),
151 OutputResult {
152 variable: var.clone(),
153 statistics,
154 samples,
155 histogram,
156 threshold_probabilities,
157 },
158 );
159 }
160
161 #[allow(clippy::cast_possible_truncation)]
163 let execution_time_ms = start.elapsed().as_millis() as u64;
164
165 Ok(SimulationResult {
166 config: self.config.clone(),
167 iterations_completed: n,
168 execution_time_ms,
169 outputs,
170 input_samples,
171 })
172 }
173
174 pub fn run_with_evaluator<F>(&mut self, mut evaluator: F) -> Result<SimulationResult, String>
186 where
187 F: FnMut(&HashMap<String, f64>) -> HashMap<String, f64>,
188 {
189 let start = Instant::now();
190 let n = self.config.iterations;
191
192 let mut input_samples: HashMap<String, Vec<f64>> = HashMap::new();
194 for (var_name, dist) in &self.distributions {
195 let samples = dist.sample_n(self.sampler.rng_mut(), n);
196 input_samples.insert(var_name.clone(), samples);
197 }
198
199 let output_vars: Vec<String> = self
201 .config
202 .outputs
203 .iter()
204 .map(|o| o.variable.clone())
205 .collect();
206 let mut output_samples: HashMap<String, Vec<f64>> = output_vars
207 .iter()
208 .map(|v| (v.clone(), Vec::with_capacity(n)))
209 .collect();
210
211 for i in 0..n {
213 let mut inputs: HashMap<String, f64> = HashMap::new();
215 for (var, samples) in &input_samples {
216 inputs.insert(var.clone(), samples[i]);
217 }
218
219 let outputs = evaluator(&inputs);
221
222 for var in &output_vars {
224 let value = outputs.get(var).copied().unwrap_or(0.0);
225 output_samples.get_mut(var).unwrap().push(value);
226 }
227 }
228
229 let mut outputs = HashMap::new();
231 for output_config in &self.config.outputs {
232 let var = &output_config.variable;
233 let samples = output_samples.get(var).cloned().unwrap_or_default();
234
235 let statistics = Statistics::from_samples(&samples);
236 let histogram = Histogram::from_samples(&samples, 50);
237
238 let mut threshold_probabilities = HashMap::new();
239 if let Some(threshold_str) = &output_config.threshold {
240 if let Ok((op, value)) = parse_threshold(threshold_str) {
241 let prob = evaluate_threshold(&samples, &op, value);
242 threshold_probabilities.insert(threshold_str.clone(), prob);
243 }
244 }
245
246 outputs.insert(
247 var.clone(),
248 OutputResult {
249 variable: var.clone(),
250 statistics,
251 samples,
252 histogram,
253 threshold_probabilities,
254 },
255 );
256 }
257
258 #[allow(clippy::cast_possible_truncation)]
260 let execution_time_ms = start.elapsed().as_millis() as u64;
261
262 Ok(SimulationResult {
263 config: self.config.clone(),
264 iterations_completed: n,
265 execution_time_ms,
266 outputs,
267 input_samples,
268 })
269 }
270
271 #[must_use]
273 pub const fn sampler(&self) -> &Sampler {
274 &self.sampler
275 }
276
277 pub const fn sampler_mut(&mut self) -> &mut Sampler {
279 &mut self.sampler
280 }
281}
282
283impl SimulationResult {
284 #[must_use]
286 pub fn to_yaml(&self) -> String {
287 use std::fmt::Write;
288
289 let mut output = String::new();
290
291 output.push_str("monte_carlo_results:\n");
292 let _ = writeln!(output, " iterations: {}", self.iterations_completed);
293 let _ = writeln!(output, " execution_time_ms: {}", self.execution_time_ms);
294 let _ = writeln!(output, " sampling: {}", self.config.sampling);
295 if let Some(seed) = self.config.seed {
296 let _ = writeln!(output, " seed: {seed}");
297 }
298
299 output.push_str("\n outputs:\n");
300 for (var, result) in &self.outputs {
301 let _ = writeln!(output, " {var}:");
302 let _ = writeln!(output, " mean: {:.4}", result.statistics.mean);
303 let _ = writeln!(output, " median: {:.4}", result.statistics.median);
304 let _ = writeln!(output, " std_dev: {:.4}", result.statistics.std_dev);
305 let _ = writeln!(output, " min: {:.4}", result.statistics.min);
306 let _ = writeln!(output, " max: {:.4}", result.statistics.max);
307
308 output.push_str(" percentiles:\n");
309 for (p, v) in &result.statistics.percentiles {
310 let _ = writeln!(output, " p{p}: {v:.4}");
311 }
312
313 if !result.threshold_probabilities.is_empty() {
314 output.push_str(" thresholds:\n");
315 for (t, prob) in &result.threshold_probabilities {
316 let _ = writeln!(output, " \"{t}\": {prob:.4}");
317 }
318 }
319 }
320
321 output
322 }
323
324 pub fn to_json(&self) -> Result<String, serde_json::Error> {
330 use serde_json::{json, to_string_pretty};
331
332 let mut outputs_json = serde_json::Map::new();
333 for (var, result) in &self.outputs {
334 let percentiles: serde_json::Map<String, serde_json::Value> = result
335 .statistics
336 .percentiles
337 .iter()
338 .map(|(p, v)| (format!("p{p}"), json!(v)))
339 .collect();
340
341 let thresholds: serde_json::Map<String, serde_json::Value> = result
342 .threshold_probabilities
343 .iter()
344 .map(|(t, p)| (t.clone(), json!(p)))
345 .collect();
346
347 outputs_json.insert(
348 var.clone(),
349 json!({
350 "mean": result.statistics.mean,
351 "median": result.statistics.median,
352 "std_dev": result.statistics.std_dev,
353 "min": result.statistics.min,
354 "max": result.statistics.max,
355 "percentiles": percentiles,
356 "thresholds": thresholds,
357 }),
358 );
359 }
360
361 let result_json = json!({
362 "monte_carlo_results": {
363 "iterations": self.iterations_completed,
364 "execution_time_ms": self.execution_time_ms,
365 "sampling": self.config.sampling,
366 "seed": self.config.seed,
367 "outputs": outputs_json,
368 }
369 });
370
371 to_string_pretty(&result_json)
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::monte_carlo::config::OutputConfig;

    /// Baseline configuration: 10 000 Latin-hypercube iterations with a fixed
    /// seed and a single `revenue` output thresholded at `> 100000`.
    fn test_config() -> MonteCarloConfig {
        MonteCarloConfig {
            enabled: true,
            iterations: 10000,
            sampling: "latin_hypercube".to_string(),
            seed: Some(12345),
            outputs: vec![OutputConfig {
                variable: "revenue".to_string(),
                percentiles: vec![10, 50, 90],
                threshold: Some("> 100000".to_string()),
                label: None,
            }],
            correlations: vec![],
        }
    }

    /// Engine built from `test_config()` with `revenue ~ Normal(mean, std_dev)`.
    fn revenue_engine(mean: f64, std_dev: f64) -> MonteCarloEngine {
        let mut engine = MonteCarloEngine::new(test_config()).unwrap();
        engine.add_distribution("revenue", Distribution::normal(mean, std_dev).unwrap());
        engine
    }

    #[test]
    fn test_engine_creation() {
        assert!(MonteCarloEngine::new(test_config()).is_ok());
    }

    #[test]
    fn test_add_distribution() {
        let engine = revenue_engine(100_000.0, 15_000.0);
        assert!(engine.distributions.contains_key("revenue"));
    }

    #[test]
    fn test_run_simulation() {
        let outcome = revenue_engine(100_000.0, 15_000.0).run().unwrap();

        assert_eq!(outcome.iterations_completed, 10000);
        assert!(outcome.input_samples.contains_key("revenue"));
        assert!(outcome.outputs.contains_key("revenue"));

        let revenue = &outcome.outputs["revenue"];
        assert!((revenue.statistics.mean - 100_000.0).abs() < 2_000.0);
        assert!(revenue.statistics.percentiles.contains_key(&50));
    }

    #[test]
    fn test_run_with_evaluator() {
        let profit_config = MonteCarloConfig {
            enabled: true,
            iterations: 1000,
            sampling: "latin_hypercube".to_string(),
            seed: Some(42),
            outputs: vec![OutputConfig {
                variable: "profit".to_string(),
                percentiles: vec![10, 50, 90],
                threshold: Some("> 0".to_string()),
                label: None,
            }],
            correlations: vec![],
        };

        let mut engine = MonteCarloEngine::new(profit_config).unwrap();
        engine.add_distribution("revenue", Distribution::normal(100.0, 10.0).unwrap());
        engine.add_distribution("costs", Distribution::normal(80.0, 5.0).unwrap());

        let outcome = engine
            .run_with_evaluator(|values| {
                let read = |key: &str| values.get(key).copied().unwrap_or(0.0);
                let mut evaluated = HashMap::new();
                evaluated.insert("profit".to_string(), read("revenue") - read("costs"));
                evaluated
            })
            .unwrap();

        let profit = &outcome.outputs["profit"];
        assert!((profit.statistics.mean - 20.0).abs() < 3.0);

        let positive_share = profit.threshold_probabilities.get("> 0").unwrap();
        assert!(*positive_share > 0.9);
    }

    #[test]
    fn test_output_yaml() {
        let yaml = revenue_engine(100_000.0, 15_000.0).run().unwrap().to_yaml();

        for needle in ["monte_carlo_results:", "iterations: 10000", "mean:", "percentiles:"] {
            assert!(yaml.contains(needle));
        }
    }

    #[test]
    fn test_output_json() {
        let json = revenue_engine(100_000.0, 15_000.0)
            .run()
            .unwrap()
            .to_json()
            .unwrap();

        for needle in ["\"monte_carlo_results\"", "\"iterations\": 10000", "\"mean\""] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn test_seed_reproducibility() {
        let first = revenue_engine(100.0, 10.0).run().unwrap();
        let second = revenue_engine(100.0, 10.0).run().unwrap();

        assert_eq!(first.input_samples["revenue"], second.input_samples["revenue"]);
    }
}