//! Causal DAG model (`datasynth_core/models/causal_dag.rs`): parameters of the
//! generation model linked by causal edges, with intervention propagation.

use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;

5/// A directed acyclic graph defining causal relationships
6/// between parameters in the generation model.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CausalDAG {
9    pub nodes: Vec<CausalNode>,
10    pub edges: Vec<CausalEdge>,
11    /// Pre-computed topological order (filled at validation time).
12    #[serde(skip)]
13    pub topological_order: Vec<String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CausalNode {
18    /// Unique identifier (matches config parameter path or abstract name).
19    pub id: String,
20    pub label: String,
21    pub category: NodeCategory,
22    /// Default/baseline value.
23    pub baseline_value: f64,
24    /// Valid range for this parameter.
25    pub bounds: Option<(f64, f64)>,
26    /// Whether this node can be directly intervened upon.
27    #[serde(default = "default_true")]
28    pub interventionable: bool,
29    /// Maps to config path(s) for actual generation parameters.
30    #[serde(default)]
31    pub config_bindings: Vec<String>,
32}
33
/// Serde default for [`CausalNode::interventionable`]: a node is
/// interventionable unless the input says otherwise.
fn default_true() -> bool {
    true
}

38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum NodeCategory {
41    Macro,
42    Operational,
43    Control,
44    Financial,
45    Behavioral,
46    Regulatory,
47    Outcome,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct CausalEdge {
52    pub from: String,
53    pub to: String,
54    pub transfer: TransferFunction,
55    /// Delay in months before the effect propagates.
56    #[serde(default)]
57    pub lag_months: u32,
58    /// Strength multiplier (0.0 = no effect, 1.0 = full transfer).
59    #[serde(default = "default_strength")]
60    pub strength: f64,
61    /// Human-readable description of the causal mechanism.
62    pub mechanism: Option<String>,
63}
64
/// Serde default for [`CausalEdge::strength`]: full, unscaled transfer.
fn default_strength() -> f64 {
    1.0_f64
}

69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "type", rename_all = "snake_case")]
71pub enum TransferFunction {
72    /// output = input * coefficient + intercept
73    Linear {
74        coefficient: f64,
75        #[serde(default)]
76        intercept: f64,
77    },
78    /// output = base * (1 + rate)^input
79    Exponential { base: f64, rate: f64 },
80    /// output = capacity / (1 + e^(-steepness * (input - midpoint)))
81    Logistic {
82        capacity: f64,
83        midpoint: f64,
84        steepness: f64,
85    },
86    /// output = capacity / (1 + e^(steepness * (input - midpoint)))
87    InverseLogistic {
88        capacity: f64,
89        midpoint: f64,
90        steepness: f64,
91    },
92    /// output = magnitude when input crosses threshold, else 0
93    Step { threshold: f64, magnitude: f64 },
94    /// output = magnitude when input > threshold, scaling linearly above
95    Threshold {
96        threshold: f64,
97        magnitude: f64,
98        #[serde(default = "default_saturation")]
99        saturation: f64,
100    },
101    /// output = initial * e^(-decay_rate * input)
102    Decay { initial: f64, decay_rate: f64 },
103    /// Lookup table with linear interpolation between points.
104    Piecewise { points: Vec<(f64, f64)> },
105}
106
/// Serde default for the `Threshold` transfer's `saturation`: no upper cap.
fn default_saturation() -> f64 {
    f64::INFINITY
}

111impl TransferFunction {
112    /// Compute the output value for a given input.
113    pub fn compute(&self, input: f64) -> f64 {
114        match self {
115            TransferFunction::Linear {
116                coefficient,
117                intercept,
118            } => input * coefficient + intercept,
119
120            TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
121
122            TransferFunction::Logistic {
123                capacity,
124                midpoint,
125                steepness,
126            } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
127
128            TransferFunction::InverseLogistic {
129                capacity,
130                midpoint,
131                steepness,
132            } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
133
134            TransferFunction::Step {
135                threshold,
136                magnitude,
137            } => {
138                if input > *threshold {
139                    *magnitude
140                } else {
141                    0.0
142                }
143            }
144
145            TransferFunction::Threshold {
146                threshold,
147                magnitude,
148                saturation,
149            } => {
150                if input > *threshold {
151                    (magnitude * (input - threshold) / threshold).min(*saturation)
152                } else {
153                    0.0
154                }
155            }
156
157            TransferFunction::Decay {
158                initial,
159                decay_rate,
160            } => initial * (-decay_rate * input).exp(),
161
162            TransferFunction::Piecewise { points } => {
163                if points.is_empty() {
164                    return 0.0;
165                }
166                if points.len() == 1 {
167                    return points[0].1;
168                }
169
170                // Sort points by x for interpolation
171                let mut sorted = points.clone();
172                sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
173
174                // Clamp to range
175                if input <= sorted[0].0 {
176                    return sorted[0].1;
177                }
178                if input >= sorted[sorted.len() - 1].0 {
179                    return sorted[sorted.len() - 1].1;
180                }
181
182                // Linear interpolation
183                for window in sorted.windows(2) {
184                    let (x0, y0) = window[0];
185                    let (x1, y1) = window[1];
186                    if input >= x0 && input <= x1 {
187                        let t = (input - x0) / (x1 - x0);
188                        return y0 + t * (y1 - y0);
189                    }
190                }
191
192                sorted[sorted.len() - 1].1
193            }
194        }
195    }
196}
197
198/// Errors that can occur during CausalDAG operations.
199#[derive(Debug, Error)]
200pub enum CausalDAGError {
201    #[error("cycle detected in causal DAG")]
202    CycleDetected,
203    #[error("unknown node referenced in edge: {0}")]
204    UnknownNode(String),
205    #[error("duplicate node ID: {0}")]
206    DuplicateNode(String),
207    #[error("node '{0}' is not interventionable")]
208    NonInterventionable(String),
209}
210
211impl CausalDAG {
212    /// Validate the graph is a DAG (no cycles) and compute topological order.
213    pub fn validate(&mut self) -> Result<(), CausalDAGError> {
214        let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
215
216        // Check for duplicate IDs
217        let mut seen = HashSet::new();
218        for node in &self.nodes {
219            if !seen.insert(&node.id) {
220                return Err(CausalDAGError::DuplicateNode(node.id.clone()));
221            }
222        }
223
224        // Check for unknown nodes in edges
225        for edge in &self.edges {
226            if !node_ids.contains(edge.from.as_str()) {
227                return Err(CausalDAGError::UnknownNode(edge.from.clone()));
228            }
229            if !node_ids.contains(edge.to.as_str()) {
230                return Err(CausalDAGError::UnknownNode(edge.to.clone()));
231            }
232        }
233
234        // Kahn's algorithm for topological sort
235        let mut in_degree: HashMap<&str, usize> = HashMap::new();
236        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
237
238        for node in &self.nodes {
239            in_degree.insert(&node.id, 0);
240            adjacency.insert(&node.id, Vec::new());
241        }
242
243        for edge in &self.edges {
244            *in_degree.entry(&edge.to).or_insert(0) += 1;
245            adjacency.entry(&edge.from).or_default().push(&edge.to);
246        }
247
248        let mut queue: VecDeque<&str> = VecDeque::new();
249        for (node, &degree) in &in_degree {
250            if degree == 0 {
251                queue.push_back(node);
252            }
253        }
254
255        let mut order = Vec::new();
256        while let Some(node) = queue.pop_front() {
257            order.push(node.to_string());
258            if let Some(neighbors) = adjacency.get(node) {
259                for &neighbor in neighbors {
260                    if let Some(degree) = in_degree.get_mut(neighbor) {
261                        *degree -= 1;
262                        if *degree == 0 {
263                            queue.push_back(neighbor);
264                        }
265                    }
266                }
267            }
268        }
269
270        if order.len() != self.nodes.len() {
271            return Err(CausalDAGError::CycleDetected);
272        }
273
274        self.topological_order = order;
275        Ok(())
276    }
277
278    /// Find a node by its ID.
279    pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
280        self.nodes.iter().find(|n| n.id == id)
281    }
282
283    /// Given a set of interventions (node_id → new_value), propagate
284    /// effects through the DAG in topological order.
285    pub fn propagate(
286        &self,
287        interventions: &HashMap<String, f64>,
288        month: u32,
289    ) -> HashMap<String, f64> {
290        let mut values: HashMap<String, f64> = HashMap::new();
291
292        // Initialize all nodes with baseline values
293        for node in &self.nodes {
294            values.insert(node.id.clone(), node.baseline_value);
295        }
296
297        // Override with direct interventions
298        for (node_id, value) in interventions {
299            values.insert(node_id.clone(), *value);
300        }
301
302        // Build edge lookup: to_node -> list of (from_node, edge)
303        let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
304        for edge in &self.edges {
305            incoming.entry(&edge.to).or_default().push(edge);
306        }
307
308        // Propagate in topological order
309        for node_id in &self.topological_order {
310            // Skip nodes that are directly intervened upon
311            if interventions.contains_key(node_id) {
312                continue;
313            }
314
315            if let Some(edges) = incoming.get(node_id.as_str()) {
316                let mut total_effect = 0.0;
317                let mut has_effect = false;
318
319                for edge in edges {
320                    // Check lag: only apply if enough months have passed
321                    if month < edge.lag_months {
322                        continue;
323                    }
324
325                    let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
326                    let baseline = self
327                        .find_node(&edge.from)
328                        .map(|n| n.baseline_value)
329                        .unwrap_or(0.0);
330
331                    // Compute the delta from baseline
332                    let delta = from_value - baseline;
333                    if delta.abs() < f64::EPSILON {
334                        continue;
335                    }
336
337                    // Apply transfer function to the delta
338                    let effect = edge.transfer.compute(delta) * edge.strength;
339                    total_effect += effect;
340                    has_effect = true;
341                }
342
343                if has_effect {
344                    let baseline = self
345                        .find_node(node_id)
346                        .map(|n| n.baseline_value)
347                        .unwrap_or(0.0);
348                    let mut new_value = baseline + total_effect;
349
350                    // Clamp to bounds
351                    if let Some(node) = self.find_node(node_id) {
352                        if let Some((min, max)) = node.bounds {
353                            new_value = new_value.clamp(min, max);
354                        }
355                    }
356
357                    values.insert(node_id.clone(), new_value);
358                }
359            }
360        }
361
362        values
363    }
364}
365
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ====================================================================
    // Fixtures
    // ====================================================================

    /// Unbounded, interventionable operational node without config bindings.
    fn node(id: &str, baseline: f64) -> CausalNode {
        CausalNode {
            id: id.to_owned(),
            label: id.to_owned(),
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: Vec::new(),
        }
    }

    /// Edge with no lag, full strength and no mechanism text.
    fn edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_owned(),
            to: to.to_owned(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    fn linear(coefficient: f64, intercept: f64) -> TransferFunction {
        TransferFunction::Linear {
            coefficient,
            intercept,
        }
    }

    /// Graph that has not been validated yet (empty topological order).
    fn dag(nodes: Vec<CausalNode>, edges: Vec<CausalEdge>) -> CausalDAG {
        CausalDAG {
            nodes,
            edges,
            topological_order: Vec::new(),
        }
    }

    /// Intervention map with a single entry.
    fn intervention(id: &str, value: f64) -> HashMap<String, f64> {
        std::iter::once((id.to_owned(), value)).collect()
    }

    /// Assert `|actual - expected| < tol`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    // ====================================================================
    // Transfer functions: basics
    // ====================================================================

    #[test]
    fn test_transfer_function_linear() {
        // 2.0 * 0.5 + 1.0 = 2.0
        assert_close(linear(0.5, 1.0).compute(2.0), 2.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_logistic() {
        let tf = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At the midpoint a logistic yields half of its capacity.
        assert_close(tf.compute(0.0), 0.5, 0.001);
    }

    #[test]
    fn test_transfer_function_exponential() {
        let tf = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        // 1.0 * (1 + 1)^3 = 8.0
        assert_close(tf.compute(3.0), 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_step() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        assert_close(tf.compute(3.0), 0.0, f64::EPSILON);
        assert_close(tf.compute(6.0), 10.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        assert_close(tf.compute(1.0), 0.0, f64::EPSILON); // below threshold
        assert_close(tf.compute(3.0), 5.0, 0.001); // 10 * (3 - 2) / 2
    }

    #[test]
    fn test_transfer_function_decay() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        assert_close(tf.compute(0.0), 100.0, 0.001); // 100 * e^0
        assert_close(tf.compute(1.0), 60.653, 0.1); // 100 * e^-0.5
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)],
        };
        // Two interior interpolations, then clamping below and above the range.
        for (input, expected) in [(0.5, 5.0), (1.5, 12.5), (-1.0, 0.0), (3.0, 15.0)] {
            assert_close(tf.compute(input), expected, 0.001);
        }
    }

    // ====================================================================
    // DAG validation and propagation
    // ====================================================================

    #[test]
    fn test_dag_validate_acyclic() {
        let mut chain = dag(
            vec![node("a", 1.0), node("b", 2.0), node("c", 3.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        assert!(chain.validate().is_ok());
        assert_eq!(chain.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut cyclic = dag(
            vec![node("a", 1.0), node("b", 2.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "a", linear(1.0, 0.0)),
            ],
        );
        assert!(matches!(cyclic.validate(), Err(CausalDAGError::CycleDetected)));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut dangling = dag(
            vec![node("a", 1.0)],
            vec![edge("a", "nonexistent", linear(1.0, 0.0))],
        );
        assert!(matches!(
            dangling.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut duplicated = dag(vec![node("a", 1.0), node("a", 2.0)], Vec::new());
        assert!(matches!(
            duplicated.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    #[test]
    fn test_dag_propagate_chain() {
        let mut chain = dag(
            vec![node("a", 10.0), node("b", 5.0), node("c", 0.0)],
            vec![
                edge("a", "b", linear(0.5, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        chain.validate().unwrap();

        // Setting A to 20.0 moves it 10.0 above its baseline.
        let out = chain.propagate(&intervention("a", 20.0), 0);
        assert_close(out["a"], 20.0, 0.001); // set directly
        assert_close(out["b"], 10.0, 0.001); // 5 + 0.5 * 10
        assert_close(out["c"], 5.0, 0.001); // 0 + 1.0 * (10 - 5)
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let lagged = CausalEdge {
            lag_months: 2,
            ..edge("a", "b", linear(1.0, 0.0))
        };
        let mut graph = dag(vec![node("a", 10.0), node("b", 5.0)], vec![lagged]);
        graph.validate().unwrap();
        let interventions = intervention("a", 20.0);

        // Month 1 is inside the 2-month lag: B stays at its baseline.
        assert_close(graph.propagate(&interventions, 1)["b"], 5.0, 0.001);
        // Month 2 satisfies the lag: B = 5 + 10.
        assert_close(graph.propagate(&interventions, 2)["b"], 15.0, 0.001);
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let bounded = CausalNode {
            bounds: Some((0.0, 8.0)),
            ..node("b", 5.0)
        };
        let mut graph = dag(
            vec![node("a", 10.0), bounded],
            vec![edge("a", "b", linear(1.0, 0.0))],
        );
        graph.validate().unwrap();

        // Unclamped B would be 15.0; its upper bound caps it at 8.0.
        assert_close(graph.propagate(&intervention("a", 20.0), 0)["b"], 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_serde() {
        let json = serde_json::to_string(&linear(0.5, 1.0)).unwrap();
        let roundtrip: TransferFunction = serde_json::from_str(&json).unwrap();
        assert_close(roundtrip.compute(2.0), 2.0, f64::EPSILON);
    }

    // ====================================================================
    // Transfer functions: edge cases
    // ====================================================================

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        // A zero slope leaves only the intercept.
        let tf = linear(0.0, 5.0);
        for input in [0.0, 100.0, -100.0] {
            assert_close(tf.compute(input), 5.0, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        let tf = linear(-2.0, 10.0);
        assert_close(tf.compute(3.0), 4.0, f64::EPSILON); // -6 + 10
        assert_close(tf.compute(5.0), 0.0, f64::EPSILON); // -10 + 10
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let tf = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        // Anything to the power 0 is 1, leaving the base.
        assert_close(tf.compute(0.0), 5.0, 0.001);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        let tf = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        // 100 * 0.5^2 = 25
        assert_close(tf.compute(2.0), 25.0, 0.001);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let tf = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        assert!(tf.compute(-10.0) < 0.01); // far below: ~0
        assert_close(tf.compute(20.0), 10.0, 0.01); // far above: ~capacity
        assert_close(tf.compute(5.0), 5.0, 0.01); // midpoint: capacity / 2
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        let logistic = |steepness| TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness,
        };
        let (steep, gentle) = (logistic(10.0), logistic(0.5));

        // Both curves pass through 0.5 at the midpoint...
        assert_close(steep.compute(0.0), 0.5, 0.01);
        assert_close(gentle.compute(0.0), 0.5, 0.01);
        // ...but the steeper one approaches capacity faster.
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let tf = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        assert_close(tf.compute(0.0), 0.5, 0.001); // midpoint: capacity / 2
        assert!(tf.compute(10.0) < 0.01); // falls toward 0 above the midpoint
        assert_close(tf.compute(-10.0), 1.0, 0.01); // ~capacity far below
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // The two curves are complementary: they always sum to the capacity.
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // The comparison is strict: exactly at the threshold is still 0.
        assert_close(tf.compute(5.0), 0.0, f64::EPSILON);
        assert_close(tf.compute(5.001), 10.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let tf = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        assert_close(tf.compute(-1.0), 0.0, f64::EPSILON);
        assert_close(tf.compute(1.0), -5.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        assert_close(tf.compute(1.0), 0.0, f64::EPSILON); // below threshold
        assert_close(tf.compute(2.5), 2.5, 0.001); // 10 * 0.5 / 2
        assert_close(tf.compute(100.0), 8.0, 0.001); // 490 capped at 8
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        // Uncapped linear growth: 5 * (100 - 1) / 1
        assert_close(tf.compute(100.0), 495.0, 0.001);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        assert!(tf.compute(10.0) < 0.01);
        assert!(tf.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let tf = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        // Without decay the output stays at its initial value.
        assert_close(tf.compute(0.0), 50.0, f64::EPSILON);
        assert_close(tf.compute(100.0), 50.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let tf = TransferFunction::Piecewise {
            points: vec![(5.0, 42.0)],
        };
        assert_close(tf.compute(0.0), 42.0, f64::EPSILON);
        assert_close(tf.compute(100.0), 42.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let tf = TransferFunction::Piecewise { points: Vec::new() };
        assert_close(tf.compute(5.0), 0.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let points = vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)];
        let tf = TransferFunction::Piecewise {
            points: points.clone(),
        };
        // Evaluating at a breakpoint returns that breakpoint's y.
        for (x, y) in points {
            assert_close(tf.compute(x), y, 0.001);
        }
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        // Out-of-order points are sorted before interpolating.
        let tf = TransferFunction::Piecewise {
            points: vec![(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)],
        };
        assert_close(tf.compute(0.5), 5.0, 0.001);
        assert_close(tf.compute(1.5), 15.0, 0.001);
    }
}