// mollendorff_forge/monte_carlo/engine.rs

//! Monte Carlo Simulation Engine
//!
//! Orchestrates the simulation:
//! 1. Parse distributions from model
//! 2. Generate samples using specified method
//! 3. Evaluate formulas for each iteration
//! 4. Compute output statistics
8
9use std::collections::HashMap;
10use std::str::FromStr;
11use std::time::Instant;
12
13use super::config::MonteCarloConfig;
14use super::distributions::{parse_distribution, Distribution};
15use super::sampler::{Sampler, SamplingMethod};
16use super::statistics::{evaluate_threshold, parse_threshold, Histogram, Statistics};
17use crate::types::ParsedModel;
18
19/// Result of a Monte Carlo simulation
20#[derive(Debug, Clone)]
21pub struct SimulationResult {
22    /// Configuration used
23    pub config: MonteCarloConfig,
24    /// Number of iterations completed
25    pub iterations_completed: usize,
26    /// Execution time in milliseconds
27    pub execution_time_ms: u64,
28    /// Results for each tracked output variable
29    pub outputs: HashMap<String, OutputResult>,
30    /// All sampled values for inputs (variable -> samples)
31    pub input_samples: HashMap<String, Vec<f64>>,
32}
33
34/// Result for a single output variable
35#[derive(Debug, Clone)]
36pub struct OutputResult {
37    /// Variable name
38    pub variable: String,
39    /// Statistics for this output
40    pub statistics: Statistics,
41    /// All simulated values
42    pub samples: Vec<f64>,
43    /// Histogram data
44    pub histogram: Histogram,
45    /// Probability thresholds (threshold string -> probability)
46    pub threshold_probabilities: HashMap<String, f64>,
47}
48
49/// Monte Carlo simulation engine
50pub struct MonteCarloEngine {
51    config: MonteCarloConfig,
52    sampler: Sampler,
53    distributions: HashMap<String, Distribution>,
54}
55
56impl MonteCarloEngine {
57    /// Create a new engine with the given configuration
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if the configuration is invalid (see [`MonteCarloConfig::validate`]).
62    pub fn new(config: MonteCarloConfig) -> Result<Self, String> {
63        config.validate()?;
64
65        let method = SamplingMethod::from_str(&config.sampling)?;
66        let sampler = Sampler::new(method, config.seed);
67
68        Ok(Self {
69            config,
70            sampler,
71            distributions: HashMap::new(),
72        })
73    }
74
75    /// Add a distribution for a variable
76    pub fn add_distribution(&mut self, variable: &str, distribution: Distribution) {
77        self.distributions
78            .insert(variable.to_string(), distribution);
79    }
80
81    /// Parse distributions from a model's scalar formulas
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if any `MC.*` formula fails to parse.
86    pub fn parse_distributions_from_model(&mut self, model: &ParsedModel) -> Result<(), String> {
87        for (name, scalar) in &model.scalars {
88            if let Some(formula) = &scalar.formula {
89                let formula = formula.trim();
90                // Check if formula starts with =MC. or MC.
91                let formula_content = formula.strip_prefix('=').unwrap_or(formula);
92
93                if formula_content.starts_with("MC.") {
94                    let dist = parse_distribution(formula_content)?;
95                    self.add_distribution(name, dist);
96                }
97            }
98        }
99        Ok(())
100    }
101
102    /// Run the simulation
103    ///
104    /// # Errors
105    ///
106    /// Returns an error if output variable samples cannot be resolved.
107    pub fn run(&mut self) -> Result<SimulationResult, String> {
108        let start = Instant::now();
109        let n = self.config.iterations;
110
111        // Generate samples for each distribution
112        let mut input_samples: HashMap<String, Vec<f64>> = HashMap::new();
113
114        for (var_name, dist) in &self.distributions {
115            let samples = dist.sample_n(self.sampler.rng_mut(), n);
116            input_samples.insert(var_name.clone(), samples);
117        }
118
119        // For now, output results are the same as input samples
120        // (Full formula evaluation will be added when integrating with calculator)
121        let mut outputs = HashMap::new();
122
123        for output_config in &self.config.outputs {
124            let var = &output_config.variable;
125
126            // Get samples for this variable (either from inputs or computed)
127            // Try exact match first, then with "scalars." prefix
128            let samples = input_samples
129                .get(var)
130                .or_else(|| input_samples.get(&format!("scalars.{var}")))
131                .cloned()
132                .unwrap_or_else(|| vec![0.0; n]);
133
134            // Calculate statistics
135            let statistics = Statistics::from_samples(&samples);
136
137            // Create histogram (50 bins default)
138            let histogram = Histogram::from_samples(&samples, 50);
139
140            // Evaluate thresholds
141            let mut threshold_probabilities = HashMap::new();
142            if let Some(threshold_str) = &output_config.threshold {
143                if let Ok((op, value)) = parse_threshold(threshold_str) {
144                    let prob = evaluate_threshold(&samples, &op, value);
145                    threshold_probabilities.insert(threshold_str.clone(), prob);
146                }
147            }
148
149            outputs.insert(
150                var.clone(),
151                OutputResult {
152                    variable: var.clone(),
153                    statistics,
154                    samples,
155                    histogram,
156                    threshold_probabilities,
157                },
158            );
159        }
160
161        // cast_possible_truncation: simulation time in ms will never exceed u64::MAX
162        #[allow(clippy::cast_possible_truncation)]
163        let execution_time_ms = start.elapsed().as_millis() as u64;
164
165        Ok(SimulationResult {
166            config: self.config.clone(),
167            iterations_completed: n,
168            execution_time_ms,
169            outputs,
170            input_samples,
171        })
172    }
173
174    /// Run simulation with a custom evaluator function
175    ///
176    /// The evaluator takes input values for one iteration and returns output values.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if output variable samples cannot be resolved.
181    ///
182    /// # Panics
183    ///
184    /// Panics if an output variable configured in config is missing from `output_samples`.
185    pub fn run_with_evaluator<F>(&mut self, mut evaluator: F) -> Result<SimulationResult, String>
186    where
187        F: FnMut(&HashMap<String, f64>) -> HashMap<String, f64>,
188    {
189        let start = Instant::now();
190        let n = self.config.iterations;
191
192        // Generate samples for each distribution
193        let mut input_samples: HashMap<String, Vec<f64>> = HashMap::new();
194        for (var_name, dist) in &self.distributions {
195            let samples = dist.sample_n(self.sampler.rng_mut(), n);
196            input_samples.insert(var_name.clone(), samples);
197        }
198
199        // Initialize output sample storage
200        let output_vars: Vec<String> = self
201            .config
202            .outputs
203            .iter()
204            .map(|o| o.variable.clone())
205            .collect();
206        let mut output_samples: HashMap<String, Vec<f64>> = output_vars
207            .iter()
208            .map(|v| (v.clone(), Vec::with_capacity(n)))
209            .collect();
210
211        // Run iterations
212        for i in 0..n {
213            // Collect input values for this iteration
214            let mut inputs: HashMap<String, f64> = HashMap::new();
215            for (var, samples) in &input_samples {
216                inputs.insert(var.clone(), samples[i]);
217            }
218
219            // Evaluate
220            let outputs = evaluator(&inputs);
221
222            // Store output values
223            for var in &output_vars {
224                let value = outputs.get(var).copied().unwrap_or(0.0);
225                output_samples.get_mut(var).unwrap().push(value);
226            }
227        }
228
229        // Calculate statistics for outputs
230        let mut outputs = HashMap::new();
231        for output_config in &self.config.outputs {
232            let var = &output_config.variable;
233            let samples = output_samples.get(var).cloned().unwrap_or_default();
234
235            let statistics = Statistics::from_samples(&samples);
236            let histogram = Histogram::from_samples(&samples, 50);
237
238            let mut threshold_probabilities = HashMap::new();
239            if let Some(threshold_str) = &output_config.threshold {
240                if let Ok((op, value)) = parse_threshold(threshold_str) {
241                    let prob = evaluate_threshold(&samples, &op, value);
242                    threshold_probabilities.insert(threshold_str.clone(), prob);
243                }
244            }
245
246            outputs.insert(
247                var.clone(),
248                OutputResult {
249                    variable: var.clone(),
250                    statistics,
251                    samples,
252                    histogram,
253                    threshold_probabilities,
254                },
255            );
256        }
257
258        // cast_possible_truncation: simulation time in ms will never exceed u64::MAX
259        #[allow(clippy::cast_possible_truncation)]
260        let execution_time_ms = start.elapsed().as_millis() as u64;
261
262        Ok(SimulationResult {
263            config: self.config.clone(),
264            iterations_completed: n,
265            execution_time_ms,
266            outputs,
267            input_samples,
268        })
269    }
270
271    /// Get the sampler
272    #[must_use]
273    pub const fn sampler(&self) -> &Sampler {
274        &self.sampler
275    }
276
277    /// Get mutable sampler
278    pub const fn sampler_mut(&mut self) -> &mut Sampler {
279        &mut self.sampler
280    }
281}
282
283impl SimulationResult {
284    /// Format results as YAML string
285    #[must_use]
286    pub fn to_yaml(&self) -> String {
287        use std::fmt::Write;
288
289        let mut output = String::new();
290
291        output.push_str("monte_carlo_results:\n");
292        let _ = writeln!(output, "  iterations: {}", self.iterations_completed);
293        let _ = writeln!(output, "  execution_time_ms: {}", self.execution_time_ms);
294        let _ = writeln!(output, "  sampling: {}", self.config.sampling);
295        if let Some(seed) = self.config.seed {
296            let _ = writeln!(output, "  seed: {seed}");
297        }
298
299        output.push_str("\n  outputs:\n");
300        for (var, result) in &self.outputs {
301            let _ = writeln!(output, "    {var}:");
302            let _ = writeln!(output, "      mean: {:.4}", result.statistics.mean);
303            let _ = writeln!(output, "      median: {:.4}", result.statistics.median);
304            let _ = writeln!(output, "      std_dev: {:.4}", result.statistics.std_dev);
305            let _ = writeln!(output, "      min: {:.4}", result.statistics.min);
306            let _ = writeln!(output, "      max: {:.4}", result.statistics.max);
307
308            output.push_str("      percentiles:\n");
309            for (p, v) in &result.statistics.percentiles {
310                let _ = writeln!(output, "        p{p}: {v:.4}");
311            }
312
313            if !result.threshold_probabilities.is_empty() {
314                output.push_str("      thresholds:\n");
315                for (t, prob) in &result.threshold_probabilities {
316                    let _ = writeln!(output, "        \"{t}\": {prob:.4}");
317                }
318            }
319        }
320
321        output
322    }
323
324    /// Format results as JSON string
325    ///
326    /// # Errors
327    ///
328    /// Returns an error if JSON serialization fails.
329    pub fn to_json(&self) -> Result<String, serde_json::Error> {
330        use serde_json::{json, to_string_pretty};
331
332        let mut outputs_json = serde_json::Map::new();
333        for (var, result) in &self.outputs {
334            let percentiles: serde_json::Map<String, serde_json::Value> = result
335                .statistics
336                .percentiles
337                .iter()
338                .map(|(p, v)| (format!("p{p}"), json!(v)))
339                .collect();
340
341            let thresholds: serde_json::Map<String, serde_json::Value> = result
342                .threshold_probabilities
343                .iter()
344                .map(|(t, p)| (t.clone(), json!(p)))
345                .collect();
346
347            outputs_json.insert(
348                var.clone(),
349                json!({
350                    "mean": result.statistics.mean,
351                    "median": result.statistics.median,
352                    "std_dev": result.statistics.std_dev,
353                    "min": result.statistics.min,
354                    "max": result.statistics.max,
355                    "percentiles": percentiles,
356                    "thresholds": thresholds,
357                }),
358            );
359        }
360
361        let result_json = json!({
362            "monte_carlo_results": {
363                "iterations": self.iterations_completed,
364                "execution_time_ms": self.execution_time_ms,
365                "sampling": self.config.sampling,
366                "seed": self.config.seed,
367                "outputs": outputs_json,
368            }
369        });
370
371        to_string_pretty(&result_json)
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::monte_carlo::config::OutputConfig;

    /// Seeded 10 000-iteration configuration that tracks `revenue` with a
    /// `> 100000` threshold.
    fn test_config() -> MonteCarloConfig {
        let revenue_output = OutputConfig {
            variable: "revenue".to_string(),
            percentiles: vec![10, 50, 90],
            threshold: Some("> 100000".to_string()),
            label: None,
        };

        MonteCarloConfig {
            enabled: true,
            iterations: 10000,
            sampling: "latin_hypercube".to_string(),
            seed: Some(12345),
            outputs: vec![revenue_output],
            correlations: vec![],
        }
    }

    /// Engine built from `test_config` with a normal(100 000, 15 000)
    /// `revenue` input already registered.
    fn revenue_engine() -> MonteCarloEngine {
        let mut engine = MonteCarloEngine::new(test_config()).unwrap();
        let dist = Distribution::normal(100_000.0, 15_000.0).unwrap();
        engine.add_distribution("revenue", dist);
        engine
    }

    #[test]
    fn test_engine_creation() {
        assert!(MonteCarloEngine::new(test_config()).is_ok());
    }

    #[test]
    fn test_add_distribution() {
        let engine = revenue_engine();

        assert!(engine.distributions.contains_key("revenue"));
    }

    #[test]
    fn test_run_simulation() {
        let mut engine = revenue_engine();
        let result = engine.run().unwrap();

        assert_eq!(result.iterations_completed, 10000);
        assert!(result.input_samples.contains_key("revenue"));
        assert!(result.outputs.contains_key("revenue"));

        // The sample mean should land close to the distribution mean.
        let revenue_result = &result.outputs["revenue"];
        assert!((revenue_result.statistics.mean - 100_000.0).abs() < 2_000.0);
        assert!(revenue_result.statistics.percentiles.contains_key(&50));
    }

    #[test]
    fn test_run_with_evaluator() {
        let profit_output = OutputConfig {
            variable: "profit".to_string(),
            percentiles: vec![10, 50, 90],
            threshold: Some("> 0".to_string()),
            label: None,
        };
        let config = MonteCarloConfig {
            enabled: true,
            iterations: 1000,
            sampling: "latin_hypercube".to_string(),
            seed: Some(42),
            outputs: vec![profit_output],
            correlations: vec![],
        };

        let mut engine = MonteCarloEngine::new(config).unwrap();
        engine.add_distribution("revenue", Distribution::normal(100.0, 10.0).unwrap());
        engine.add_distribution("costs", Distribution::normal(80.0, 5.0).unwrap());

        // profit = revenue - costs, with a missing input treated as 0.
        let profit_of = |inputs: &HashMap<String, f64>| {
            let revenue = inputs.get("revenue").copied().unwrap_or(0.0);
            let costs = inputs.get("costs").copied().unwrap_or(0.0);
            let mut outputs = HashMap::new();
            outputs.insert("profit".to_string(), revenue - costs);
            outputs
        };

        let result = engine.run_with_evaluator(profit_of).unwrap();
        let profit_result = &result.outputs["profit"];

        // Expected profit mean ≈ 100 - 80 = 20
        assert!((profit_result.statistics.mean - 20.0).abs() < 3.0);

        // Profit should be positive in the vast majority of iterations.
        let prob = profit_result.threshold_probabilities.get("> 0").unwrap();
        assert!(*prob > 0.9);
    }

    #[test]
    fn test_output_yaml() {
        let mut engine = revenue_engine();
        let yaml = engine.run().unwrap().to_yaml();

        for expected in [
            "monte_carlo_results:",
            "iterations: 10000",
            "mean:",
            "percentiles:",
        ] {
            assert!(yaml.contains(expected));
        }
    }

    #[test]
    fn test_output_json() {
        let mut engine = revenue_engine();
        let json = engine.run().unwrap().to_json().unwrap();

        for expected in [
            "\"monte_carlo_results\"",
            "\"iterations\": 10000",
            "\"mean\"",
        ] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn test_seed_reproducibility() {
        // Two independently built engines sharing one seeded configuration.
        let run_once = |config: MonteCarloConfig| {
            let mut engine = MonteCarloEngine::new(config).unwrap();
            engine.add_distribution("revenue", Distribution::normal(100.0, 10.0).unwrap());
            engine.run().unwrap()
        };

        let config = test_config();
        let first = run_once(config.clone());
        let second = run_once(config);

        // Same seed should produce identical results
        assert_eq!(
            &first.input_samples["revenue"],
            &second.input_samples["revenue"]
        );
    }
}