// Source path: datasynth_core/models/causal_dag.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use thiserror::Error;
4
5/// A directed acyclic graph defining causal relationships
6/// between parameters in the generation model.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CausalDAG {
9    pub nodes: Vec<CausalNode>,
10    pub edges: Vec<CausalEdge>,
11    /// Pre-computed topological order (filled at validation time).
12    #[serde(skip)]
13    pub topological_order: Vec<String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CausalNode {
18    /// Unique identifier (matches config parameter path or abstract name).
19    pub id: String,
20    pub label: String,
21    pub category: NodeCategory,
22    /// Default/baseline value.
23    pub baseline_value: f64,
24    /// Valid range for this parameter.
25    pub bounds: Option<(f64, f64)>,
26    /// Whether this node can be directly intervened upon.
27    #[serde(default = "default_true")]
28    pub interventionable: bool,
29    /// Maps to config path(s) for actual generation parameters.
30    #[serde(default)]
31    pub config_bindings: Vec<String>,
32}
33
/// Serde default for `CausalNode::interventionable`: a node accepts
/// interventions unless the input says otherwise.
fn default_true() -> bool {
    true
}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum NodeCategory {
41    Macro,
42    Operational,
43    Control,
44    Financial,
45    Behavioral,
46    Regulatory,
47    Outcome,
48    Audit,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct CausalEdge {
53    pub from: String,
54    pub to: String,
55    pub transfer: TransferFunction,
56    /// Delay in months before the effect propagates.
57    #[serde(default)]
58    pub lag_months: u32,
59    /// Strength multiplier (0.0 = no effect, 1.0 = full transfer).
60    #[serde(default = "default_strength")]
61    pub strength: f64,
62    /// Human-readable description of the causal mechanism.
63    pub mechanism: Option<String>,
64}
65
/// Serde default for `CausalEdge::strength`: full transfer.
fn default_strength() -> f64 {
    1.0
}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
71#[serde(tag = "type", rename_all = "snake_case")]
72pub enum TransferFunction {
73    /// output = input * coefficient + intercept
74    Linear {
75        coefficient: f64,
76        #[serde(default)]
77        intercept: f64,
78    },
79    /// output = base * (1 + rate)^input
80    Exponential { base: f64, rate: f64 },
81    /// output = capacity / (1 + e^(-steepness * (input - midpoint)))
82    Logistic {
83        capacity: f64,
84        midpoint: f64,
85        steepness: f64,
86    },
87    /// output = capacity / (1 + e^(steepness * (input - midpoint)))
88    InverseLogistic {
89        capacity: f64,
90        midpoint: f64,
91        steepness: f64,
92    },
93    /// output = magnitude when input crosses threshold, else 0
94    Step { threshold: f64, magnitude: f64 },
95    /// output = magnitude when input > threshold, scaling linearly above
96    Threshold {
97        threshold: f64,
98        magnitude: f64,
99        #[serde(default = "default_saturation")]
100        saturation: f64,
101    },
102    /// output = initial * e^(-decay_rate * input)
103    Decay { initial: f64, decay_rate: f64 },
104    /// Lookup table with linear interpolation between points.
105    Piecewise { points: Vec<(f64, f64)> },
106}
107
/// Serde default for the `saturation` field of `TransferFunction::Threshold`:
/// positive infinity, i.e. the output is never capped.
fn default_saturation() -> f64 {
    f64::INFINITY
}
111
112impl TransferFunction {
113    /// Compute the output value for a given input.
114    pub fn compute(&self, input: f64) -> f64 {
115        match self {
116            TransferFunction::Linear {
117                coefficient,
118                intercept,
119            } => input * coefficient + intercept,
120
121            TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
122
123            TransferFunction::Logistic {
124                capacity,
125                midpoint,
126                steepness,
127            } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
128
129            TransferFunction::InverseLogistic {
130                capacity,
131                midpoint,
132                steepness,
133            } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
134
135            TransferFunction::Step {
136                threshold,
137                magnitude,
138            } => {
139                if input > *threshold {
140                    *magnitude
141                } else {
142                    0.0
143                }
144            }
145
146            TransferFunction::Threshold {
147                threshold,
148                magnitude,
149                saturation,
150            } => {
151                if input > *threshold {
152                    (magnitude * (input - threshold) / threshold).min(*saturation)
153                } else {
154                    0.0
155                }
156            }
157
158            TransferFunction::Decay {
159                initial,
160                decay_rate,
161            } => initial * (-decay_rate * input).exp(),
162
163            TransferFunction::Piecewise { points } => {
164                if points.is_empty() {
165                    return 0.0;
166                }
167                if points.len() == 1 {
168                    return points[0].1;
169                }
170
171                // Sort points by x for interpolation
172                let mut sorted = points.clone();
173                sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
174
175                // Clamp to range
176                if input <= sorted[0].0 {
177                    return sorted[0].1;
178                }
179                if input >= sorted[sorted.len() - 1].0 {
180                    return sorted[sorted.len() - 1].1;
181                }
182
183                // Linear interpolation
184                for window in sorted.windows(2) {
185                    let (x0, y0) = window[0];
186                    let (x1, y1) = window[1];
187                    if input >= x0 && input <= x1 {
188                        let t = (input - x0) / (x1 - x0);
189                        return y0 + t * (y1 - y0);
190                    }
191                }
192
193                sorted[sorted.len() - 1].1
194            }
195        }
196    }
197}
198
199/// Errors that can occur during CausalDAG operations.
200#[derive(Debug, Error)]
201pub enum CausalDAGError {
202    #[error("cycle detected in causal DAG")]
203    CycleDetected,
204    #[error("unknown node referenced in edge: {0}")]
205    UnknownNode(String),
206    #[error("duplicate node ID: {0}")]
207    DuplicateNode(String),
208    #[error("node '{0}' is not interventionable")]
209    NonInterventionable(String),
210}
211
212impl CausalDAG {
213    /// Validate the graph is a DAG (no cycles) and compute topological order.
214    pub fn validate(&mut self) -> Result<(), CausalDAGError> {
215        let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
216
217        // Check for duplicate IDs
218        let mut seen = HashSet::new();
219        for node in &self.nodes {
220            if !seen.insert(&node.id) {
221                return Err(CausalDAGError::DuplicateNode(node.id.clone()));
222            }
223        }
224
225        // Check for unknown nodes in edges
226        for edge in &self.edges {
227            if !node_ids.contains(edge.from.as_str()) {
228                return Err(CausalDAGError::UnknownNode(edge.from.clone()));
229            }
230            if !node_ids.contains(edge.to.as_str()) {
231                return Err(CausalDAGError::UnknownNode(edge.to.clone()));
232            }
233        }
234
235        // Kahn's algorithm for topological sort
236        let mut in_degree: HashMap<&str, usize> = HashMap::new();
237        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
238
239        for node in &self.nodes {
240            in_degree.insert(&node.id, 0);
241            adjacency.insert(&node.id, Vec::new());
242        }
243
244        for edge in &self.edges {
245            *in_degree.entry(&edge.to).or_insert(0) += 1;
246            adjacency.entry(&edge.from).or_default().push(&edge.to);
247        }
248
249        let mut queue: VecDeque<&str> = VecDeque::new();
250        for (node, &degree) in &in_degree {
251            if degree == 0 {
252                queue.push_back(node);
253            }
254        }
255
256        let mut order = Vec::new();
257        while let Some(node) = queue.pop_front() {
258            order.push(node.to_string());
259            if let Some(neighbors) = adjacency.get(node) {
260                for &neighbor in neighbors {
261                    if let Some(degree) = in_degree.get_mut(neighbor) {
262                        *degree -= 1;
263                        if *degree == 0 {
264                            queue.push_back(neighbor);
265                        }
266                    }
267                }
268            }
269        }
270
271        if order.len() != self.nodes.len() {
272            return Err(CausalDAGError::CycleDetected);
273        }
274
275        self.topological_order = order;
276        Ok(())
277    }
278
279    /// Find a node by its ID.
280    pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
281        self.nodes.iter().find(|n| n.id == id)
282    }
283
284    /// Given a set of interventions (node_id → new_value), propagate
285    /// effects through the DAG in topological order.
286    pub fn propagate(
287        &self,
288        interventions: &HashMap<String, f64>,
289        month: u32,
290    ) -> HashMap<String, f64> {
291        let mut values: HashMap<String, f64> = HashMap::new();
292
293        // Initialize all nodes with baseline values
294        for node in &self.nodes {
295            values.insert(node.id.clone(), node.baseline_value);
296        }
297
298        // Override with direct interventions
299        for (node_id, value) in interventions {
300            values.insert(node_id.clone(), *value);
301        }
302
303        // Build edge lookup: to_node -> list of (from_node, edge)
304        let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
305        for edge in &self.edges {
306            incoming.entry(&edge.to).or_default().push(edge);
307        }
308
309        // Propagate in topological order
310        for node_id in &self.topological_order {
311            // Skip nodes that are directly intervened upon
312            if interventions.contains_key(node_id) {
313                continue;
314            }
315
316            if let Some(edges) = incoming.get(node_id.as_str()) {
317                let mut total_effect = 0.0;
318                let mut has_effect = false;
319
320                for edge in edges {
321                    // Check lag: only apply if enough months have passed
322                    if month < edge.lag_months {
323                        continue;
324                    }
325
326                    let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
327                    let baseline = self
328                        .find_node(&edge.from)
329                        .map(|n| n.baseline_value)
330                        .unwrap_or(0.0);
331
332                    // Compute the delta from baseline
333                    let delta = from_value - baseline;
334                    if delta.abs() < f64::EPSILON {
335                        continue;
336                    }
337
338                    // Apply transfer function to the delta
339                    let effect = edge.transfer.compute(delta) * edge.strength;
340                    total_effect += effect;
341                    has_effect = true;
342                }
343
344                if has_effect {
345                    let baseline = self
346                        .find_node(node_id)
347                        .map(|n| n.baseline_value)
348                        .unwrap_or(0.0);
349                    let mut new_value = baseline + total_effect;
350
351                    // Clamp to bounds
352                    if let Some(node) = self.find_node(node_id) {
353                        if let Some((min, max)) = node.bounds {
354                            new_value = new_value.clamp(min, max);
355                        }
356                    }
357
358                    values.insert(node_id.clone(), new_value);
359                }
360            }
361        }
362
363        values
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used where the expected value is exact in floating point.
    const EXACT: f64 = f64::EPSILON;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn near(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Assert `tf.compute(input)` is within `tol` of `expected` for each pair.
    fn check(tf: &TransferFunction, tol: f64, cases: &[(f64, f64)]) {
        for &(input, expected) in cases {
            let actual = tf.compute(input);
            assert!(
                near(actual, expected, tol),
                "compute({}) = {}, expected {}",
                input,
                actual,
                expected
            );
        }
    }

    /// Operational node with no bounds, labelled with its own ID.
    fn node(id: &str, baseline: f64) -> CausalNode {
        CausalNode {
            id: id.to_string(),
            label: id.to_string(),
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: vec![],
        }
    }

    /// Edge with no lag, full strength and no mechanism text.
    fn edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_string(),
            to: to.to_string(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    fn linear(coefficient: f64, intercept: f64) -> TransferFunction {
        TransferFunction::Linear {
            coefficient,
            intercept,
        }
    }

    /// Graph that has not been validated yet.
    fn graph(nodes: Vec<CausalNode>, edges: Vec<CausalEdge>) -> CausalDAG {
        CausalDAG {
            nodes,
            edges,
            topological_order: vec![],
        }
    }

    /// Single-node intervention map.
    fn intervene(id: &str, value: f64) -> HashMap<String, f64> {
        let mut interventions = HashMap::new();
        interventions.insert(id.to_string(), value);
        interventions
    }

    #[test]
    fn test_transfer_function_linear() {
        // 2.0 * 0.5 + 1.0 = 2.0
        check(&linear(0.5, 1.0), EXACT, &[(2.0, 2.0)]);
    }

    #[test]
    fn test_transfer_function_logistic() {
        let tf = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At the midpoint the logistic returns capacity / 2.
        check(&tf, 0.001, &[(0.0, 0.5)]);
    }

    #[test]
    fn test_transfer_function_exponential() {
        let tf = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        // 1.0 * 2.0^3.0 = 8.0
        check(&tf, 0.001, &[(3.0, 8.0)]);
    }

    #[test]
    fn test_transfer_function_step() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        check(&tf, EXACT, &[(3.0, 0.0), (6.0, 10.0)]);
    }

    #[test]
    fn test_transfer_function_threshold() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        // Below the threshold nothing happens.
        check(&tf, EXACT, &[(1.0, 0.0)]);
        // Above it: 10.0 * (3.0 - 2.0) / 2.0 = 5.0
        check(&tf, 0.001, &[(3.0, 5.0)]);
    }

    #[test]
    fn test_transfer_function_decay() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        // 100.0 * e^0 = 100.0
        check(&tf, 0.001, &[(0.0, 100.0)]);
        // 100.0 * e^(-0.5) ≈ 60.65
        check(&tf, 0.1, &[(1.0, 60.653)]);
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)],
        };
        check(
            &tf,
            0.001,
            &[
                (0.5, 5.0),  // between (0,0) and (1,10)
                (1.5, 12.5), // between (1,10) and (2,15)
                (-1.0, 0.0), // below range: first point
                (3.0, 15.0), // above range: last point
            ],
        );
    }

    #[test]
    fn test_dag_validate_acyclic() {
        let mut dag = graph(
            vec![node("a", 1.0), node("b", 2.0), node("c", 3.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        assert!(dag.validate().is_ok());
        assert_eq!(dag.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut dag = graph(
            vec![node("a", 1.0), node("b", 2.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "a", linear(1.0, 0.0)),
            ],
        );
        assert!(matches!(dag.validate(), Err(CausalDAGError::CycleDetected)));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut dag = graph(
            vec![node("a", 1.0)],
            vec![edge("a", "nonexistent", linear(1.0, 0.0))],
        );
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut dag = graph(vec![node("a", 1.0), node("a", 2.0)], vec![]);
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    #[test]
    fn test_dag_propagate_chain() {
        let mut dag = graph(
            vec![node("a", 10.0), node("b", 5.0), node("c", 0.0)],
            vec![
                edge("a", "b", linear(0.5, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        dag.validate().unwrap();

        // Push A from 10.0 to 20.0 (delta = 10.0).
        let result = dag.propagate(&intervene("a", 20.0), 0);
        // A is pinned to the intervention.
        assert!(near(result["a"], 20.0, 0.001));
        // B = 5.0 + 10.0 * 0.5 = 10.0
        assert!(near(result["b"], 10.0, 0.001));
        // C = 0.0 + 5.0 * 1.0 = 5.0
        assert!(near(result["c"], 5.0, 0.001));
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let delayed = CausalEdge {
            lag_months: 2,
            ..edge("a", "b", linear(1.0, 0.0))
        };
        let mut dag = graph(vec![node("a", 10.0), node("b", 5.0)], vec![delayed]);
        dag.validate().unwrap();

        let interventions = intervene("a", 20.0);

        // Month 1 is before the 2-month lag: B stays at baseline.
        let result = dag.propagate(&interventions, 1);
        assert!(near(result["b"], 5.0, 0.001));

        // Month 2 meets the lag: B = 5.0 + 10.0 = 15.0
        let result = dag.propagate(&interventions, 2);
        assert!(near(result["b"], 15.0, 0.001));
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let bounded = CausalNode {
            bounds: Some((0.0, 8.0)),
            ..node("b", 5.0)
        };
        let mut dag = graph(
            vec![node("a", 10.0), bounded],
            vec![edge("a", "b", linear(1.0, 0.0))],
        );
        dag.validate().unwrap();

        // delta = 10.0 would put B at 15.0; the upper bound caps it at 8.0.
        let result = dag.propagate(&intervene("a", 20.0), 0);
        assert!(near(result["b"], 8.0, 0.001));
    }

    #[test]
    fn test_transfer_function_serde() {
        let json = serde_json::to_string(&linear(0.5, 1.0)).unwrap();
        let deserialized: TransferFunction = serde_json::from_str(&json).unwrap();
        check(&deserialized, EXACT, &[(2.0, 2.0)]);
    }

    // ====================================================================
    // Comprehensive transfer function tests (Task 12)
    // ====================================================================

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        // Any input collapses to the intercept.
        check(
            &linear(0.0, 5.0),
            EXACT,
            &[(0.0, 5.0), (100.0, 5.0), (-100.0, 5.0)],
        );
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        // -6 + 10 = 4 and -10 + 10 = 0
        check(&linear(-2.0, 10.0), EXACT, &[(3.0, 4.0), (5.0, 0.0)]);
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let tf = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        // (1 + 0.5)^0 = 1, so the result is the base.
        check(&tf, 0.001, &[(0.0, 5.0)]);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        let tf = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        // 100 * 0.5^2 = 25.0
        check(&tf, 0.001, &[(2.0, 25.0)]);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let tf = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        // Far below the midpoint: near 0.
        assert!(tf.compute(-10.0) < 0.01);
        // Far above: near capacity; at the midpoint: capacity / 2.
        check(&tf, 0.01, &[(20.0, 10.0), (5.0, 5.0)]);
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        let steep = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 10.0,
        };
        let gentle = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 0.5,
        };
        // Both sit at ~0.5 on the midpoint.
        check(&steep, 0.01, &[(0.0, 0.5)]);
        check(&gentle, 0.01, &[(0.0, 0.5)]);
        // The steeper curve is closer to capacity at input = 1.0.
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let tf = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At the midpoint: capacity / 2.
        check(&tf, 0.001, &[(0.0, 0.5)]);
        // Decreasing curve: far above the midpoint it approaches 0.
        assert!(tf.compute(10.0) < 0.01);
        // Far below the midpoint it approaches capacity.
        check(&tf, 0.01, &[(-10.0, 1.0)]);
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // The two curves add up to the capacity everywhere.
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // The comparison is strict: exactly on the threshold still yields 0,
        // just above it yields the magnitude.
        check(&tf, EXACT, &[(5.0, 0.0), (5.001, 10.0)]);
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let tf = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        check(&tf, EXACT, &[(-1.0, 0.0), (1.0, -5.0)]);
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        // Below the threshold: 0.
        check(&tf, EXACT, &[(1.0, 0.0)]);
        // 10.0 * (2.5 - 2.0) / 2.0 = 2.5; the uncapped 490 at input 100 is
        // limited to saturation = 8.0.
        check(&tf, 0.001, &[(2.5, 2.5), (100.0, 8.0)]);
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        // No cap: 5.0 * (100.0 - 1.0) / 1.0 = 495.0
        check(&tf, 0.001, &[(100.0, 495.0)]);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        // Large inputs approach 0.
        assert!(tf.compute(10.0) < 0.01);
        assert!(tf.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let tf = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        // No decay: constant output.
        check(&tf, EXACT, &[(0.0, 50.0), (100.0, 50.0)]);
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let tf = TransferFunction::Piecewise {
            points: vec![(5.0, 42.0)],
        };
        // A single point answers every input with its own value.
        check(&tf, EXACT, &[(0.0, 42.0), (100.0, 42.0)]);
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let tf = TransferFunction::Piecewise { points: vec![] };
        check(&tf, EXACT, &[(5.0, 0.0)]);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)],
        };
        // Exactly on the breakpoints.
        check(
            &tf,
            0.001,
            &[(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)],
        );
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        // Points listed out of order must interpolate as if sorted.
        let tf = TransferFunction::Piecewise {
            points: vec![(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)],
        };
        check(&tf, 0.001, &[(0.5, 5.0), (1.5, 15.0)]);
    }
}