// datasynth_core/models/causal_dag.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use thiserror::Error;
4
/// A directed acyclic graph defining causal relationships
/// between parameters in the generation model.
///
/// Call `validate` after building or deserializing a graph: it rejects duplicate node IDs,
/// edges that reference unknown nodes, and cycles, and fills `topological_order`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalDAG {
    /// Parameters of the model, identified by `CausalNode::id`.
    pub nodes: Vec<CausalNode>,
    /// Directed cause -> effect links that reference nodes by ID.
    pub edges: Vec<CausalEdge>,
    /// Pre-computed topological order (filled at validation time).
    ///
    /// Not serialized (`#[serde(skip)]`), so it is empty after deserialization until
    /// `validate` runs.
    #[serde(skip)]
    pub topological_order: Vec<String>,
}
15
/// A parameter of the causal model; a vertex of [`CausalDAG`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalNode {
    /// Unique identifier (matches config parameter path or abstract name).
    pub id: String,
    /// Free-form label; not used by validation or propagation in this file.
    pub label: String,
    /// Classification of the parameter.
    pub category: NodeCategory,
    /// Default/baseline value.
    ///
    /// Propagation measures changes to this node as deviations from this value.
    pub baseline_value: f64,
    /// Valid range for this parameter, given as `(min, max)`.
    ///
    /// Values computed by propagation are clamped into it; directly intervened values are not.
    pub bounds: Option<(f64, f64)>,
    /// Whether this node can be directly intervened upon.
    ///
    /// Defaults to `true` when the field is absent from serialized input.
    #[serde(default = "default_true")]
    pub interventionable: bool,
    /// Maps to config path(s) for actual generation parameters.
    ///
    /// Defaults to an empty list when absent from serialized input.
    #[serde(default)]
    pub config_bindings: Vec<String>,
}

/// Serde default for [`CausalNode::interventionable`]: nodes are interventionable unless
/// the input says otherwise.
fn default_true() -> bool {
    true
}
37
/// Classification of a [`CausalNode`].
///
/// Serialized in snake_case (e.g. `"macro"`, `"financial"`). Not consulted by validation or
/// propagation in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum NodeCategory {
    Macro,
    Operational,
    Control,
    Financial,
    Behavioral,
    Regulatory,
    Outcome,
    Audit,
}
50
/// A directed causal link from node `from` to node `to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalEdge {
    /// ID of the cause (source) node.
    pub from: String,
    /// ID of the effect (target) node.
    pub to: String,
    /// Applied to the source node's deviation from its baseline to obtain the effect on `to`.
    pub transfer: TransferFunction,
    /// Delay in months before the effect propagates.
    ///
    /// The edge contributes only when the propagation month is `>= lag_months`; defaults to
    /// 0 (immediate) when absent from serialized input.
    #[serde(default)]
    pub lag_months: u32,
    /// Strength multiplier (0.0 = no effect, 1.0 = full transfer).
    ///
    /// Multiplies the transfer function's output; defaults to 1.0 when absent.
    #[serde(default = "default_strength")]
    pub strength: f64,
    /// Human-readable description of the causal mechanism.
    pub mechanism: Option<String>,
}

/// Serde default for [`CausalEdge::strength`]: full transfer.
fn default_strength() -> f64 {
    1.0
}
69
/// Maps an input (a source node's deviation from baseline) to an effect size.
///
/// Serialized as an internally tagged object whose `type` field is the snake_case variant
/// name, e.g. `{"type": "inverse_logistic", ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TransferFunction {
    /// output = input * coefficient + intercept
    ///
    /// `intercept` defaults to 0.0 when absent from serialized input.
    Linear {
        coefficient: f64,
        #[serde(default)]
        intercept: f64,
    },
    /// output = base * (1 + rate)^input
    Exponential { base: f64, rate: f64 },
    /// output = capacity / (1 + e^(-steepness * (input - midpoint)))
    Logistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// output = capacity / (1 + e^(steepness * (input - midpoint)))
    InverseLogistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// output = magnitude when input > threshold (strictly greater), else 0
    Step { threshold: f64, magnitude: f64 },
    /// output = 0 when input <= threshold; above it the output grows linearly with the excess.
    ///
    /// For a positive `threshold`:
    /// `output = min(magnitude * (input - threshold) / threshold, saturation)`.
    /// `saturation` is an upper cap on the output and defaults to +infinity (no cap) when
    /// absent from serialized input.
    Threshold {
        threshold: f64,
        magnitude: f64,
        #[serde(default = "default_saturation")]
        saturation: f64,
    },
    /// output = initial * e^(-decay_rate * input)
    Decay { initial: f64, decay_rate: f64 },
    /// Lookup table with linear interpolation between points.
    ///
    /// Points are `(x, y)` pairs and need not be sorted by `x`. Inputs outside the covered
    /// range return the `y` of the nearest end point; an empty table yields 0 and a single
    /// point yields its `y`.
    Piecewise { points: Vec<(f64, f64)> },
}

/// Serde default for `TransferFunction::Threshold::saturation`: no cap.
fn default_saturation() -> f64 {
    f64::INFINITY
}
111
112impl TransferFunction {
113    /// Compute the output value for a given input.
114    pub fn compute(&self, input: f64) -> f64 {
115        match self {
116            TransferFunction::Linear {
117                coefficient,
118                intercept,
119            } => input * coefficient + intercept,
120
121            TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
122
123            TransferFunction::Logistic {
124                capacity,
125                midpoint,
126                steepness,
127            } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
128
129            TransferFunction::InverseLogistic {
130                capacity,
131                midpoint,
132                steepness,
133            } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
134
135            TransferFunction::Step {
136                threshold,
137                magnitude,
138            } => {
139                if input > *threshold {
140                    *magnitude
141                } else {
142                    0.0
143                }
144            }
145
146            TransferFunction::Threshold {
147                threshold,
148                magnitude,
149                saturation,
150            } => {
151                if input > *threshold {
152                    (magnitude * (input - threshold) / threshold).min(*saturation)
153                } else {
154                    0.0
155                }
156            }
157
158            TransferFunction::Decay {
159                initial,
160                decay_rate,
161            } => initial * (-decay_rate * input).exp(),
162
163            TransferFunction::Piecewise { points } => {
164                if points.is_empty() {
165                    return 0.0;
166                }
167                if points.len() == 1 {
168                    return points[0].1;
169                }
170
171                // Sort points by x for interpolation
172                let mut sorted = points.clone();
173                sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
174
175                // Clamp to range
176                if input <= sorted[0].0 {
177                    return sorted[0].1;
178                }
179                if input >= sorted[sorted.len() - 1].0 {
180                    return sorted[sorted.len() - 1].1;
181                }
182
183                // Linear interpolation
184                for window in sorted.windows(2) {
185                    let (x0, y0) = window[0];
186                    let (x1, y1) = window[1];
187                    if input >= x0 && input <= x1 {
188                        let t = (input - x0) / (x1 - x0);
189                        return y0 + t * (y1 - y0);
190                    }
191                }
192
193                sorted[sorted.len() - 1].1
194            }
195        }
196    }
197}
198
/// Errors that can occur during CausalDAG operations.
#[derive(Debug, Error)]
pub enum CausalDAGError {
    /// The edges contain at least one cycle (a self-loop counts), so no topological order
    /// exists.
    #[error("cycle detected in causal DAG")]
    CycleDetected,
    /// An edge's `from` or `to` names an ID that is not among the DAG's nodes (the ID is
    /// carried).
    #[error("unknown node referenced in edge: {0}")]
    UnknownNode(String),
    /// More than one node uses the carried ID.
    #[error("duplicate node ID: {0}")]
    DuplicateNode(String),
    /// The carried node ID is reported as not interventionable.
    ///
    /// NOTE(review): not constructed anywhere in this file; presumably raised by callers that
    /// check `CausalNode::interventionable` before intervening — confirm.
    #[error("node '{0}' is not interventionable")]
    NonInterventionable(String),
}
211
212impl CausalDAG {
213    /// Validate the graph is a DAG (no cycles) and compute topological order.
214    pub fn validate(&mut self) -> Result<(), CausalDAGError> {
215        let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
216
217        // Check for duplicate IDs
218        let mut seen = HashSet::new();
219        for node in &self.nodes {
220            if !seen.insert(&node.id) {
221                return Err(CausalDAGError::DuplicateNode(node.id.clone()));
222            }
223        }
224
225        // Check for unknown nodes in edges
226        for edge in &self.edges {
227            if !node_ids.contains(edge.from.as_str()) {
228                return Err(CausalDAGError::UnknownNode(edge.from.clone()));
229            }
230            if !node_ids.contains(edge.to.as_str()) {
231                return Err(CausalDAGError::UnknownNode(edge.to.clone()));
232            }
233        }
234
235        // Kahn's algorithm for topological sort
236        let mut in_degree: HashMap<&str, usize> = HashMap::new();
237        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
238
239        for node in &self.nodes {
240            in_degree.insert(&node.id, 0);
241            adjacency.insert(&node.id, Vec::new());
242        }
243
244        for edge in &self.edges {
245            *in_degree.entry(&edge.to).or_insert(0) += 1;
246            adjacency.entry(&edge.from).or_default().push(&edge.to);
247        }
248
249        let mut queue: VecDeque<&str> = VecDeque::new();
250        for (node, &degree) in &in_degree {
251            if degree == 0 {
252                queue.push_back(node);
253            }
254        }
255
256        let mut order = Vec::new();
257        while let Some(node) = queue.pop_front() {
258            order.push(node.to_string());
259            if let Some(neighbors) = adjacency.get(node) {
260                for &neighbor in neighbors {
261                    if let Some(degree) = in_degree.get_mut(neighbor) {
262                        *degree -= 1;
263                        if *degree == 0 {
264                            queue.push_back(neighbor);
265                        }
266                    }
267                }
268            }
269        }
270
271        if order.len() != self.nodes.len() {
272            return Err(CausalDAGError::CycleDetected);
273        }
274
275        self.topological_order = order;
276        Ok(())
277    }
278
279    /// Find a node by its ID.
280    pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
281        self.nodes.iter().find(|n| n.id == id)
282    }
283
284    /// Given a set of interventions (node_id → new_value), propagate
285    /// effects through the DAG in topological order.
286    pub fn propagate(
287        &self,
288        interventions: &HashMap<String, f64>,
289        month: u32,
290    ) -> HashMap<String, f64> {
291        let mut values: HashMap<String, f64> = HashMap::new();
292
293        // Initialize all nodes with baseline values
294        for node in &self.nodes {
295            values.insert(node.id.clone(), node.baseline_value);
296        }
297
298        // Override with direct interventions
299        for (node_id, value) in interventions {
300            values.insert(node_id.clone(), *value);
301        }
302
303        // Build edge lookup: to_node -> list of (from_node, edge)
304        let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
305        for edge in &self.edges {
306            incoming.entry(&edge.to).or_default().push(edge);
307        }
308
309        // Propagate in topological order
310        for node_id in &self.topological_order {
311            // Skip nodes that are directly intervened upon
312            if interventions.contains_key(node_id) {
313                continue;
314            }
315
316            if let Some(edges) = incoming.get(node_id.as_str()) {
317                let mut total_effect = 0.0;
318                let mut has_effect = false;
319
320                for edge in edges {
321                    // Check lag: only apply if enough months have passed
322                    if month < edge.lag_months {
323                        continue;
324                    }
325
326                    let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
327                    let baseline = self
328                        .find_node(&edge.from)
329                        .map(|n| n.baseline_value)
330                        .unwrap_or(0.0);
331
332                    // Compute the delta from baseline
333                    let delta = from_value - baseline;
334                    if delta.abs() < f64::EPSILON {
335                        continue;
336                    }
337
338                    // Apply transfer function to the delta
339                    let effect = edge.transfer.compute(delta) * edge.strength;
340                    total_effect += effect;
341                    has_effect = true;
342                }
343
344                if has_effect {
345                    let baseline = self
346                        .find_node(node_id)
347                        .map(|n| n.baseline_value)
348                        .unwrap_or(0.0);
349                    let mut new_value = baseline + total_effect;
350
351                    // Clamp to bounds
352                    if let Some(node) = self.find_node(node_id) {
353                        if let Some((min, max)) = node.bounds {
354                            new_value = new_value.clamp(min, max);
355                        }
356                    }
357
358                    values.insert(node_id.clone(), new_value);
359                }
360            }
361        }
362
363        values
364    }
365}
366
// Unit tests for the transfer functions, DAG validation and propagation.
#[cfg(test)]
// Panicking on `unwrap` is the intended failure mode inside tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an `Operational`, interventionable node without bounds whose label equals its ID.
    fn make_node(id: &str, baseline: f64) -> CausalNode {
        CausalNode {
            id: id.to_string(),
            label: id.to_string(),
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: vec![],
        }
    }

    /// Build an immediate (lag 0), full-strength (1.0) edge without a mechanism description.
    fn make_edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_string(),
            to: to.to_string(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    #[test]
    fn test_transfer_function_linear() {
        let tf = TransferFunction::Linear {
            coefficient: 0.5,
            intercept: 1.0,
        };
        let result = tf.compute(2.0);
        assert!((result - 2.0).abs() < f64::EPSILON); // 2.0 * 0.5 + 1.0 = 2.0
    }

    #[test]
    fn test_transfer_function_logistic() {
        let tf = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At midpoint, logistic returns capacity/2
        let result = tf.compute(0.0);
        assert!((result - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_exponential() {
        let tf = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        // base * (1 + rate)^input = 1.0 * 2.0^3.0 = 8.0
        let result = tf.compute(3.0);
        assert!((result - 8.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_step() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        assert!((tf.compute(3.0) - 0.0).abs() < f64::EPSILON);
        assert!((tf.compute(6.0) - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        // Below threshold: 0
        assert!((tf.compute(1.0) - 0.0).abs() < f64::EPSILON);
        // Above threshold: 10.0 * (3.0 - 2.0) / 2.0 = 5.0
        assert!((tf.compute(3.0) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_decay() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        // At input=0: 100.0 * e^0 = 100.0
        assert!((tf.compute(0.0) - 100.0).abs() < 0.001);
        // At input=1: 100.0 * e^(-0.5) ≈ 60.65
        assert!((tf.compute(1.0) - 60.653).abs() < 0.1);
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)],
        };
        // At 0.5: interpolate between (0,0) and (1,10) → 5.0
        assert!((tf.compute(0.5) - 5.0).abs() < 0.001);
        // At 1.5: interpolate between (1,10) and (2,15) → 12.5
        assert!((tf.compute(1.5) - 12.5).abs() < 0.001);
        // Below range: clamp to first point
        assert!((tf.compute(-1.0) - 0.0).abs() < 0.001);
        // Above range: clamp to last point
        assert!((tf.compute(3.0) - 15.0).abs() < 0.001);
    }

    // ====================================================================
    // DAG validation
    // ====================================================================

    #[test]
    fn test_dag_validate_acyclic() {
        let mut dag = CausalDAG {
            nodes: vec![
                make_node("a", 1.0),
                make_node("b", 2.0),
                make_node("c", 3.0),
            ],
            edges: vec![
                make_edge(
                    "a",
                    "b",
                    TransferFunction::Linear {
                        coefficient: 1.0,
                        intercept: 0.0,
                    },
                ),
                make_edge(
                    "b",
                    "c",
                    TransferFunction::Linear {
                        coefficient: 1.0,
                        intercept: 0.0,
                    },
                ),
            ],
            topological_order: vec![],
        };
        assert!(dag.validate().is_ok());
        // A simple chain admits exactly one topological order.
        assert_eq!(dag.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut dag = CausalDAG {
            nodes: vec![make_node("a", 1.0), make_node("b", 2.0)],
            edges: vec![
                make_edge(
                    "a",
                    "b",
                    TransferFunction::Linear {
                        coefficient: 1.0,
                        intercept: 0.0,
                    },
                ),
                make_edge(
                    "b",
                    "a",
                    TransferFunction::Linear {
                        coefficient: 1.0,
                        intercept: 0.0,
                    },
                ),
            ],
            topological_order: vec![],
        };
        assert!(matches!(dag.validate(), Err(CausalDAGError::CycleDetected)));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut dag = CausalDAG {
            nodes: vec![make_node("a", 1.0)],
            edges: vec![make_edge(
                "a",
                "nonexistent",
                TransferFunction::Linear {
                    coefficient: 1.0,
                    intercept: 0.0,
                },
            )],
            topological_order: vec![],
        };
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut dag = CausalDAG {
            nodes: vec![make_node("a", 1.0), make_node("a", 2.0)],
            edges: vec![],
            topological_order: vec![],
        };
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    // ====================================================================
    // Propagation
    // ====================================================================

    #[test]
    fn test_dag_propagate_chain() {
        let mut dag = CausalDAG {
            nodes: vec![
                make_node("a", 10.0),
                make_node("b", 5.0),
                make_node("c", 0.0),
            ],
            edges: vec![
                make_edge(
                    "a",
                    "b",
                    TransferFunction::Linear {
                        coefficient: 0.5,
                        intercept: 0.0,
                    },
                ),
                make_edge(
                    "b",
                    "c",
                    TransferFunction::Linear {
                        coefficient: 1.0,
                        intercept: 0.0,
                    },
                ),
            ],
            topological_order: vec![],
        };
        dag.validate().unwrap();

        // Intervene on A: set to 20.0 (delta = 10.0)
        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0);

        let result = dag.propagate(&interventions, 0);
        // A = 20.0 (directly set)
        assert!((result["a"] - 20.0).abs() < 0.001);
        // B baseline = 5.0, delta_a = 10.0, transfer = 10.0 * 0.5 + 0.0 = 5.0 → B = 5.0 + 5.0 = 10.0
        assert!((result["b"] - 10.0).abs() < 0.001);
        // C baseline = 0.0, delta_b = 5.0, transfer = 5.0 * 1.0 + 0.0 = 5.0 → C = 0.0 + 5.0 = 5.0
        assert!((result["c"] - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let mut dag = CausalDAG {
            nodes: vec![make_node("a", 10.0), make_node("b", 5.0)],
            edges: vec![CausalEdge {
                from: "a".to_string(),
                to: "b".to_string(),
                transfer: TransferFunction::Linear {
                    coefficient: 1.0,
                    intercept: 0.0,
                },
                lag_months: 2,
                strength: 1.0,
                mechanism: None,
            }],
            topological_order: vec![],
        };
        dag.validate().unwrap();

        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0);

        // Month 1: lag is 2, so no effect yet
        let result = dag.propagate(&interventions, 1);
        assert!((result["b"] - 5.0).abs() < 0.001); // unchanged from baseline

        // Month 2: lag is met, effect propagates
        let result = dag.propagate(&interventions, 2);
        // delta_a = 10.0, transfer = 10.0, B = 5.0 + 10.0 = 15.0
        assert!((result["b"] - 15.0).abs() < 0.001);
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let mut dag = CausalDAG {
            nodes: vec![make_node("a", 10.0), {
                let mut n = make_node("b", 5.0);
                n.bounds = Some((0.0, 8.0));
                n
            }],
            edges: vec![make_edge(
                "a",
                "b",
                TransferFunction::Linear {
                    coefficient: 1.0,
                    intercept: 0.0,
                },
            )],
            topological_order: vec![],
        };
        dag.validate().unwrap();

        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0); // delta = 10.0 → B would be 15.0

        let result = dag.propagate(&interventions, 0);
        // B should be clamped to max bound of 8.0
        assert!((result["b"] - 8.0).abs() < 0.001);
    }

    // Round-trips through the internally tagged (`"type"`) JSON representation.
    #[test]
    fn test_transfer_function_serde() {
        let tf = TransferFunction::Linear {
            coefficient: 0.5,
            intercept: 1.0,
        };
        let json = serde_json::to_string(&tf).unwrap();
        let deserialized: TransferFunction = serde_json::from_str(&json).unwrap();
        assert!((deserialized.compute(2.0) - 2.0).abs() < f64::EPSILON);
    }

    // ====================================================================
    // Comprehensive transfer function tests (Task 12)
    // ====================================================================

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        let tf = TransferFunction::Linear {
            coefficient: 0.0,
            intercept: 5.0,
        };
        // Any input → just the intercept
        assert!((tf.compute(0.0) - 5.0).abs() < f64::EPSILON);
        assert!((tf.compute(100.0) - 5.0).abs() < f64::EPSILON);
        assert!((tf.compute(-100.0) - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        let tf = TransferFunction::Linear {
            coefficient: -2.0,
            intercept: 10.0,
        };
        assert!((tf.compute(3.0) - 4.0).abs() < f64::EPSILON); // -6 + 10 = 4
        assert!((tf.compute(5.0) - 0.0).abs() < f64::EPSILON); // -10 + 10 = 0
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let tf = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        // (1+0.5)^0 = 1, so result = 5.0
        assert!((tf.compute(0.0) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        let tf = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        // (1 + (-0.5))^2 = 0.5^2 = 0.25, result = 25.0
        assert!((tf.compute(2.0) - 25.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let tf = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        // Far below midpoint → near 0
        assert!(tf.compute(-10.0) < 0.01);
        // Far above midpoint → near capacity
        assert!((tf.compute(20.0) - 10.0).abs() < 0.01);
        // At midpoint → capacity/2
        assert!((tf.compute(5.0) - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        // High steepness → sharper transition
        let steep = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 10.0,
        };
        let gentle = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 0.5,
        };
        // Both should be ~0.5 at midpoint
        assert!((steep.compute(0.0) - 0.5).abs() < 0.01);
        assert!((gentle.compute(0.0) - 0.5).abs() < 0.01);
        // Steep should be closer to 1.0 at input=1.0
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let tf = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At midpoint → capacity/2
        assert!((tf.compute(0.0) - 0.5).abs() < 0.001);
        // Inverse logistic decreases: far above midpoint → near 0
        assert!(tf.compute(10.0) < 0.01);
        // Far below midpoint → near capacity
        assert!((tf.compute(-10.0) - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // Logistic + InverseLogistic should sum to capacity at any point
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // At exactly the threshold, should be 0 (not strictly greater)
        assert!((tf.compute(5.0) - 0.0).abs() < f64::EPSILON);
        // Just above threshold
        assert!((tf.compute(5.001) - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let tf = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        assert!((tf.compute(-1.0) - 0.0).abs() < f64::EPSILON);
        assert!((tf.compute(1.0) - (-5.0)).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        // Below threshold: 0
        assert!((tf.compute(1.0) - 0.0).abs() < f64::EPSILON);
        // Just above threshold: 10.0 * (2.5 - 2.0) / 2.0 = 2.5
        assert!((tf.compute(2.5) - 2.5).abs() < 0.001);
        // Way above threshold without saturation: 10.0 * (100.0 - 2.0) / 2.0 = 490
        // But capped at saturation=8.0
        assert!((tf.compute(100.0) - 8.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        // No saturation cap: grows linearly
        // 5.0 * (100.0 - 1.0) / 1.0 = 495.0
        assert!((tf.compute(100.0) - 495.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        // Large input → approaches 0
        assert!(tf.compute(10.0) < 0.01);
        assert!(tf.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let tf = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        // No decay → constant
        assert!((tf.compute(0.0) - 50.0).abs() < f64::EPSILON);
        assert!((tf.compute(100.0) - 50.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let tf = TransferFunction::Piecewise {
            points: vec![(5.0, 42.0)],
        };
        // Single point → always returns that value
        assert!((tf.compute(0.0) - 42.0).abs() < f64::EPSILON);
        assert!((tf.compute(100.0) - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let tf = TransferFunction::Piecewise { points: vec![] };
        // No points → 0
        assert!((tf.compute(5.0) - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)],
        };
        // At exact breakpoints
        assert!((tf.compute(0.0) - 0.0).abs() < 0.001);
        assert!((tf.compute(1.0) - 10.0).abs() < 0.001);
        assert!((tf.compute(2.0) - 15.0).abs() < 0.001);
        assert!((tf.compute(3.0) - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        // Points given out of order — should still interpolate correctly
        let tf = TransferFunction::Piecewise {
            points: vec![(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)],
        };
        assert!((tf.compute(0.5) - 5.0).abs() < 0.001);
        assert!((tf.compute(1.5) - 15.0).abs() < 0.001);
    }
}