mollendorff_forge/tornado/
engine.rs1use super::config::{InputRange, TornadoConfig};
6use crate::core::ArrayCalculator;
7use crate::types::{ParsedModel, Variable};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SensitivityBar {
13 pub input_name: String,
15 pub output_at_low: f64,
17 pub output_at_high: f64,
19 pub swing: f64,
21 pub abs_swing: f64,
23 pub input_low: f64,
25 pub input_high: f64,
27}
28
29impl SensitivityBar {
30 #[allow(
33 clippy::cast_possible_truncation,
34 clippy::cast_sign_loss,
35 clippy::cast_precision_loss
36 )]
37 #[must_use]
38 pub fn to_ascii(&self, max_swing: f64, bar_width: usize) -> String {
39 let ratio = self.abs_swing / max_swing;
40 let filled = (ratio * bar_width as f64) as usize;
41 let bar: String = "█".repeat(filled);
42 format!(
43 "{:<20} |{:<width$}| +/- ${:.0}",
44 self.input_name,
45 bar,
46 self.abs_swing / 2.0,
47 width = bar_width
48 )
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct TornadoResult {
55 pub output: String,
57 pub base_value: f64,
59 pub bars: Vec<SensitivityBar>,
61 pub total_variance: f64,
63}
64
65impl TornadoResult {
66 #[must_use]
68 pub fn to_yaml(&self) -> String {
69 serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
70 }
71
72 pub fn to_json(&self) -> Result<String, serde_json::Error> {
78 serde_json::to_string_pretty(self)
79 }
80
81 pub fn to_ascii(&self) -> String {
83 use std::fmt::Write;
84
85 let mut output = String::new();
86
87 let _ = write!(
88 output,
89 "{} Sensitivity (Base: ${:.0})\n\n",
90 self.output, self.base_value
91 );
92
93 if self.bars.is_empty() {
94 output.push_str("No sensitivity data\n");
95 return output;
96 }
97
98 let max_swing = self
99 .bars
100 .iter()
101 .map(|b| b.abs_swing)
102 .fold(0.0_f64, f64::max);
103
104 for bar in &self.bars {
105 output.push_str(&bar.to_ascii(max_swing, 30));
106 output.push('\n');
107 }
108
109 output
110 }
111
112 #[must_use]
114 pub fn top_drivers(&self, n: usize) -> Vec<&SensitivityBar> {
115 self.bars.iter().take(n).collect()
116 }
117
118 #[must_use]
120 pub fn variance_explained_by_top(&self, n: usize) -> f64 {
121 if self.total_variance == 0.0 {
122 return 0.0;
123 }
124 let top_variance: f64 = self.bars.iter().take(n).map(|b| b.abs_swing).sum();
125 top_variance / self.total_variance * 100.0
126 }
127}
128
129pub struct TornadoEngine {
131 config: TornadoConfig,
132 base_model: ParsedModel,
133}
134
135impl TornadoEngine {
136 pub fn new(config: TornadoConfig, base_model: ParsedModel) -> Result<Self, String> {
142 config.validate()?;
143 Ok(Self { config, base_model })
144 }
145
146 pub fn analyze(&self) -> Result<TornadoResult, String> {
152 let base_value = self.calculate_output(&self.base_model)?;
154
155 let input_bases: Vec<f64> = self
158 .config
159 .inputs
160 .iter()
161 .filter_map(|input| {
162 input.base.or_else(|| {
163 self.base_model
164 .scalars
165 .get(&input.name)
166 .and_then(|v| v.value)
167 })
168 })
169 .map(f64::abs)
170 .filter(|v| *v > 1e-10)
171 .collect();
172
173 let needs_normalization = if input_bases.len() >= 2 {
175 let max_base = input_bases.iter().fold(0.0_f64, |a, &b| a.max(b));
176 let min_base = input_bases.iter().fold(f64::INFINITY, |a, &b| a.min(b));
177 max_base / min_base > 100.0 } else {
179 false
180 };
181
182 let mut bars: Vec<SensitivityBar> = Vec::new();
184
185 for input in &self.config.inputs {
186 let bar = self.calculate_sensitivity(input, base_value, needs_normalization)?;
187 bars.push(bar);
188 }
189
190 bars.sort_by(|a, b| {
192 b.abs_swing
193 .partial_cmp(&a.abs_swing)
194 .unwrap_or(std::cmp::Ordering::Equal)
195 });
196
197 let total_variance: f64 = bars.iter().map(|b| b.abs_swing).sum();
199
200 Ok(TornadoResult {
201 output: self.config.output.clone(),
202 base_value,
203 bars,
204 total_variance,
205 })
206 }
207
208 fn calculate_sensitivity(
210 &self,
211 input: &InputRange,
212 _base_value: f64,
213 needs_normalization: bool,
214 ) -> Result<SensitivityBar, String> {
215 let output_at_low = self.calculate_with_override(&input.name, input.low)?;
217
218 let output_at_high = self.calculate_with_override(&input.name, input.high)?;
220
221 let swing = output_at_high - output_at_low;
222 let raw_abs_swing = swing.abs();
223
224 let abs_swing = if needs_normalization {
226 let input_base = input
228 .base
229 .or_else(|| {
230 self.base_model
231 .scalars
232 .get(&input.name)
233 .and_then(|v| v.value)
234 })
235 .unwrap_or(1.0); let input_range = input.high - input.low;
240 let relative_range = if input_base.abs() > 1e-10 {
241 input_range / input_base.abs()
242 } else {
243 1.0 };
245
246 raw_abs_swing * relative_range * relative_range
249 } else {
250 raw_abs_swing
252 };
253
254 Ok(SensitivityBar {
255 input_name: input.name.clone(),
256 output_at_low,
257 output_at_high,
258 swing,
259 abs_swing,
260 input_low: input.low,
261 input_high: input.high,
262 })
263 }
264
265 fn calculate_with_override(&self, input_name: &str, input_value: f64) -> Result<f64, String> {
267 let mut model = self.base_model.clone();
268
269 if let Some(scalar) = model.scalars.get_mut(input_name) {
271 scalar.value = Some(input_value);
272 scalar.formula = None; } else {
274 model.scalars.insert(
276 input_name.to_string(),
277 Variable::new(input_name.to_string(), Some(input_value), None),
278 );
279 }
280
281 self.calculate_output(&model)
282 }
283
284 fn calculate_output(&self, model: &ParsedModel) -> Result<f64, String> {
286 let calculator = ArrayCalculator::new(model.clone());
287 let result = calculator.calculate_all().map_err(|e| e.to_string())?;
288
289 result
290 .scalars
291 .get(&self.config.output)
292 .and_then(|v| v.value)
293 .ok_or_else(|| {
294 format!(
295 "Output variable '{}' not found or has no value",
296 self.config.output
297 )
298 })
299 }
300
301 #[must_use]
303 pub const fn config(&self) -> &TornadoConfig {
304 &self.config
305 }
306}
307
#[cfg(test)]
mod engine_tests {
    use super::*;

    /// Model with literal `revenue`, `cost_rate`, `tax_rate` and
    /// `profit = revenue * (1 - cost_rate) * (1 - tax_rate)`.
    fn create_test_model() -> ParsedModel {
        let literals: [(&str, f64); 3] = [
            ("revenue", 1_000_000.0),
            ("cost_rate", 0.60),
            ("tax_rate", 0.25),
        ];

        let mut model = ParsedModel::new();
        for &(name, value) in &literals {
            model.scalars.insert(
                name.to_string(),
                Variable::new(name.to_string(), Some(value), None),
            );
        }
        model.scalars.insert(
            "profit".to_string(),
            Variable::new(
                "profit".to_string(),
                None,
                Some("=revenue * (1 - cost_rate) * (1 - tax_rate)".to_string()),
            ),
        );
        model
    }

    /// Sweep of `revenue` and `cost_rate` on `profit`.
    fn two_input_config() -> TornadoConfig {
        TornadoConfig::new("profit")
            .with_input(InputRange::new("revenue", 800_000.0, 1_200_000.0))
            .with_input(InputRange::new("cost_rate", 0.50, 0.70))
    }

    /// Sweep of all three literal inputs on `profit`.
    fn three_input_config() -> TornadoConfig {
        two_input_config().with_input(InputRange::new("tax_rate", 0.20, 0.30))
    }

    /// Sweep of `revenue` alone on `profit`.
    fn revenue_only_config() -> TornadoConfig {
        TornadoConfig::new("profit").with_input(InputRange::new(
            "revenue",
            800_000.0,
            1_200_000.0,
        ))
    }

    /// Build an engine over the test model and run the analysis.
    fn run_analysis(config: TornadoConfig) -> TornadoResult {
        let engine = TornadoEngine::new(config, create_test_model()).unwrap();
        engine.analyze().unwrap()
    }

    #[test]
    fn test_tornado_analysis() {
        let result = run_analysis(three_input_config());

        assert_eq!(result.bars.len(), 3);
        assert!(
            result
                .bars
                .windows(2)
                .all(|pair| pair[0].abs_swing >= pair[1].abs_swing),
            "Bars should be sorted by impact"
        );
        assert_eq!(result.bars[0].input_name, "revenue");
    }

    #[test]
    fn test_ascii_output() {
        let ascii = run_analysis(two_input_config()).to_ascii();

        for expected in ["profit Sensitivity", "revenue", "cost_rate"] {
            assert!(ascii.contains(expected));
        }
    }

    #[test]
    fn test_top_drivers() {
        let result = run_analysis(three_input_config());

        assert_eq!(result.top_drivers(2).len(), 2);
        assert!(
            result.variance_explained_by_top(2) > 50.0,
            "Top 2 should explain > 50% of variance"
        );
    }

    #[test]
    fn test_yaml_export() {
        let yaml = run_analysis(revenue_only_config()).to_yaml();

        assert!(yaml.contains("output: profit"));
        assert!(yaml.contains("bars:"));
    }

    #[test]
    fn test_json_export() {
        let json = run_analysis(revenue_only_config()).to_json().unwrap();

        assert!(json.contains("\"output\""));
        assert!(json.contains("\"bars\""));
    }
}