Skip to main content

mollendorff_forge/tornado/
engine.rs

1//! Tornado Diagram Engine
2//!
3//! Performs one-at-a-time sensitivity analysis.
4
5use super::config::{InputRange, TornadoConfig};
6use crate::core::ArrayCalculator;
7use crate::types::{ParsedModel, Variable};
8use serde::{Deserialize, Serialize};
9
10/// A single sensitivity bar in the tornado diagram
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SensitivityBar {
13    /// Input variable name
14    pub input_name: String,
15    /// Output at low input value
16    pub output_at_low: f64,
17    /// Output at high input value
18    pub output_at_high: f64,
19    /// Total swing (high - low output)
20    pub swing: f64,
21    /// Absolute swing for sorting
22    pub abs_swing: f64,
23    /// Low input value used
24    pub input_low: f64,
25    /// High input value used
26    pub input_high: f64,
27}
28
29impl SensitivityBar {
30    /// Generate ASCII bar representation
31    // Truncation is mathematically impossible: ratio is 0..=1, bar_width is small
32    #[allow(
33        clippy::cast_possible_truncation,
34        clippy::cast_sign_loss,
35        clippy::cast_precision_loss
36    )]
37    #[must_use]
38    pub fn to_ascii(&self, max_swing: f64, bar_width: usize) -> String {
39        let ratio = self.abs_swing / max_swing;
40        let filled = (ratio * bar_width as f64) as usize;
41        let bar: String = "█".repeat(filled);
42        format!(
43            "{:<20} |{:<width$}| +/- ${:.0}",
44            self.input_name,
45            bar,
46            self.abs_swing / 2.0,
47            width = bar_width
48        )
49    }
50}
51
52/// Complete tornado analysis result
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct TornadoResult {
55    /// Output variable name
56    pub output: String,
57    /// Base case output value
58    pub base_value: f64,
59    /// Sensitivity bars (sorted by impact)
60    pub bars: Vec<SensitivityBar>,
61    /// Total variance explained
62    pub total_variance: f64,
63}
64
65impl TornadoResult {
66    /// Export results to YAML format
67    #[must_use]
68    pub fn to_yaml(&self) -> String {
69        serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
70    }
71
72    /// Export results to JSON format
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if JSON serialization fails.
77    pub fn to_json(&self) -> Result<String, serde_json::Error> {
78        serde_json::to_string_pretty(self)
79    }
80
81    /// Generate ASCII tornado diagram
82    pub fn to_ascii(&self) -> String {
83        use std::fmt::Write;
84
85        let mut output = String::new();
86
87        let _ = write!(
88            output,
89            "{} Sensitivity (Base: ${:.0})\n\n",
90            self.output, self.base_value
91        );
92
93        if self.bars.is_empty() {
94            output.push_str("No sensitivity data\n");
95            return output;
96        }
97
98        let max_swing = self
99            .bars
100            .iter()
101            .map(|b| b.abs_swing)
102            .fold(0.0_f64, f64::max);
103
104        for bar in &self.bars {
105            output.push_str(&bar.to_ascii(max_swing, 30));
106            output.push('\n');
107        }
108
109        output
110    }
111
112    /// Get top N drivers
113    #[must_use]
114    pub fn top_drivers(&self, n: usize) -> Vec<&SensitivityBar> {
115        self.bars.iter().take(n).collect()
116    }
117
118    /// Calculate percentage of variance explained by top N drivers
119    #[must_use]
120    pub fn variance_explained_by_top(&self, n: usize) -> f64 {
121        if self.total_variance == 0.0 {
122            return 0.0;
123        }
124        let top_variance: f64 = self.bars.iter().take(n).map(|b| b.abs_swing).sum();
125        top_variance / self.total_variance * 100.0
126    }
127}
128
129/// Tornado Diagram Engine
130pub struct TornadoEngine {
131    config: TornadoConfig,
132    base_model: ParsedModel,
133}
134
135impl TornadoEngine {
136    /// Create a new tornado engine
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if the configuration is invalid.
141    pub fn new(config: TornadoConfig, base_model: ParsedModel) -> Result<Self, String> {
142        config.validate()?;
143        Ok(Self { config, base_model })
144    }
145
146    /// Run the sensitivity analysis
147    ///
148    /// # Errors
149    ///
150    /// Returns an error if the base model or any sensitivity calculation fails.
151    pub fn analyze(&self) -> Result<TornadoResult, String> {
152        // Calculate base case
153        let base_value = self.calculate_output(&self.base_model)?;
154
155        // Determine if we need cross-scale normalization
156        // Get base values for all inputs to check if they span multiple orders of magnitude
157        let input_bases: Vec<f64> = self
158            .config
159            .inputs
160            .iter()
161            .filter_map(|input| {
162                input.base.or_else(|| {
163                    self.base_model
164                        .scalars
165                        .get(&input.name)
166                        .and_then(|v| v.value)
167                })
168            })
169            .map(f64::abs)
170            .filter(|v| *v > 1e-10)
171            .collect();
172
173        // Check if inputs span vastly different scales (e.g., dollars vs rates)
174        let needs_normalization = if input_bases.len() >= 2 {
175            let max_base = input_bases.iter().fold(0.0_f64, |a, &b| a.max(b));
176            let min_base = input_bases.iter().fold(f64::INFINITY, |a, &b| a.min(b));
177            max_base / min_base > 100.0 // More than 2 orders of magnitude difference
178        } else {
179            false
180        };
181
182        // Calculate sensitivity for each input
183        let mut bars: Vec<SensitivityBar> = Vec::new();
184
185        for input in &self.config.inputs {
186            let bar = self.calculate_sensitivity(input, base_value, needs_normalization)?;
187            bars.push(bar);
188        }
189
190        // Sort by absolute swing (largest impact first)
191        bars.sort_by(|a, b| {
192            b.abs_swing
193                .partial_cmp(&a.abs_swing)
194                .unwrap_or(std::cmp::Ordering::Equal)
195        });
196
197        // Calculate total variance
198        let total_variance: f64 = bars.iter().map(|b| b.abs_swing).sum();
199
200        Ok(TornadoResult {
201            output: self.config.output.clone(),
202            base_value,
203            bars,
204            total_variance,
205        })
206    }
207
208    /// Calculate sensitivity for a single input
209    fn calculate_sensitivity(
210        &self,
211        input: &InputRange,
212        _base_value: f64,
213        needs_normalization: bool,
214    ) -> Result<SensitivityBar, String> {
215        // Calculate output at low input value
216        let output_at_low = self.calculate_with_override(&input.name, input.low)?;
217
218        // Calculate output at high input value
219        let output_at_high = self.calculate_with_override(&input.name, input.high)?;
220
221        let swing = output_at_high - output_at_low;
222        let raw_abs_swing = swing.abs();
223
224        // Apply normalization only when comparing inputs of vastly different scales
225        let abs_swing = if needs_normalization {
226            // Get the base input value from the model to calculate relative range
227            let input_base = input
228                .base
229                .or_else(|| {
230                    self.base_model
231                        .scalars
232                        .get(&input.name)
233                        .and_then(|v| v.value)
234                })
235                .unwrap_or(1.0); // Default to 1.0 if no base found
236
237            // Calculate relative input range (as fraction of base)
238            // This normalizes inputs of different scales (e.g., dollars vs rates)
239            let input_range = input.high - input.low;
240            let relative_range = if input_base.abs() > 1e-10 {
241                input_range / input_base.abs()
242            } else {
243                1.0 // Avoid division by zero
244            };
245
246            // Weight sensitivity by square of relative range
247            // This ensures inputs varied by larger percentages rank higher
248            raw_abs_swing * relative_range * relative_range
249        } else {
250            // For inputs of similar scale, use raw absolute swing
251            raw_abs_swing
252        };
253
254        Ok(SensitivityBar {
255            input_name: input.name.clone(),
256            output_at_low,
257            output_at_high,
258            swing,
259            abs_swing,
260            input_low: input.low,
261            input_high: input.high,
262        })
263    }
264
265    /// Calculate output with a specific input override
266    fn calculate_with_override(&self, input_name: &str, input_value: f64) -> Result<f64, String> {
267        let mut model = self.base_model.clone();
268
269        // Override the input value
270        if let Some(scalar) = model.scalars.get_mut(input_name) {
271            scalar.value = Some(input_value);
272            scalar.formula = None; // Clear formula to use override value
273        } else {
274            // Create new scalar
275            model.scalars.insert(
276                input_name.to_string(),
277                Variable::new(input_name.to_string(), Some(input_value), None),
278            );
279        }
280
281        self.calculate_output(&model)
282    }
283
284    /// Calculate the output variable value
285    fn calculate_output(&self, model: &ParsedModel) -> Result<f64, String> {
286        let calculator = ArrayCalculator::new(model.clone());
287        let result = calculator.calculate_all().map_err(|e| e.to_string())?;
288
289        result
290            .scalars
291            .get(&self.config.output)
292            .and_then(|v| v.value)
293            .ok_or_else(|| {
294                format!(
295                    "Output variable '{}' not found or has no value",
296                    self.config.output
297                )
298            })
299    }
300
301    /// Get the configuration
302    #[must_use]
303    pub const fn config(&self) -> &TornadoConfig {
304        &self.config
305    }
306}
307
#[cfg(test)]
mod engine_tests {
    use super::*;

    // Sweep ranges shared by the tests: (input name, low, high).
    const REVENUE: (&str, f64, f64) = ("revenue", 800_000.0, 1_200_000.0);
    const COST_RATE: (&str, f64, f64) = ("cost_rate", 0.50, 0.70);
    const TAX_RATE: (&str, f64, f64) = ("tax_rate", 0.20, 0.30);

    /// Builds a three-input model whose output is
    /// `profit = revenue * (1 - cost_rate) * (1 - tax_rate)`.
    fn create_test_model() -> ParsedModel {
        let mut model = ParsedModel::new();

        let inputs = [
            ("revenue", 1_000_000.0),
            ("cost_rate", 0.60),
            ("tax_rate", 0.25),
        ];
        for &(name, value) in &inputs {
            model
                .scalars
                .insert(name.to_string(), Variable::new(name.to_string(), Some(value), None));
        }

        let formula = "=revenue * (1 - cost_rate) * (1 - tax_rate)".to_string();
        model.scalars.insert(
            "profit".to_string(),
            Variable::new("profit".to_string(), None, Some(formula)),
        );

        model
    }

    /// Runs a tornado analysis of `profit` over `ranges` against the test model.
    fn analyze_profit(ranges: &[(&'static str, f64, f64)]) -> TornadoResult {
        let config = ranges
            .iter()
            .fold(TornadoConfig::new("profit"), |config, &(name, low, high)| {
                config.with_input(InputRange::new(name, low, high))
            });
        let engine = TornadoEngine::new(config, create_test_model()).unwrap();
        engine.analyze().unwrap()
    }

    #[test]
    fn test_tornado_analysis() {
        let result = analyze_profit(&[REVENUE, COST_RATE, TAX_RATE]);

        // One bar per configured input.
        assert_eq!(result.bars.len(), 3);

        // Every adjacent pair must be in non-increasing order of impact.
        assert!(
            result
                .bars
                .windows(2)
                .all(|pair| pair[0].abs_swing >= pair[1].abs_swing),
            "Bars should be sorted by impact"
        );

        // Revenue should have the biggest impact (absolute dollars).
        assert_eq!(result.bars[0].input_name, "revenue");
    }

    #[test]
    fn test_ascii_output() {
        let ascii = analyze_profit(&[REVENUE, COST_RATE]).to_ascii();

        for needle in ["profit Sensitivity", "revenue", "cost_rate"] {
            assert!(ascii.contains(needle));
        }
    }

    #[test]
    fn test_top_drivers() {
        let result = analyze_profit(&[REVENUE, COST_RATE, TAX_RATE]);

        assert_eq!(result.top_drivers(2).len(), 2);

        // The two largest bars should dominate the total swing.
        let pct = result.variance_explained_by_top(2);
        assert!(pct > 50.0, "Top 2 should explain > 50% of variance");
    }

    #[test]
    fn test_yaml_export() {
        let yaml = analyze_profit(&[REVENUE]).to_yaml();

        assert!(yaml.contains("output: profit"));
        assert!(yaml.contains("bars:"));
    }

    #[test]
    fn test_json_export() {
        let json = analyze_profit(&[REVENUE]).to_json().unwrap();

        assert!(json.contains("\"output\""));
        assert!(json.contains("\"bars\""));
    }
}