// datasynth_core/models/causal_dag.rs

use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;

5/// A directed acyclic graph defining causal relationships
6/// between parameters in the generation model.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CausalDAG {
9    pub nodes: Vec<CausalNode>,
10    pub edges: Vec<CausalEdge>,
11    /// Pre-computed topological order (filled at validation time).
12    #[serde(skip)]
13    pub topological_order: Vec<String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CausalNode {
18    /// Unique identifier (matches config parameter path or abstract name).
19    pub id: String,
20    pub label: String,
21    pub category: NodeCategory,
22    /// Default/baseline value.
23    pub baseline_value: f64,
24    /// Valid range for this parameter.
25    pub bounds: Option<(f64, f64)>,
26    /// Whether this node can be directly intervened upon.
27    #[serde(default = "default_true")]
28    pub interventionable: bool,
29    /// Maps to config path(s) for actual generation parameters.
30    #[serde(default)]
31    pub config_bindings: Vec<String>,
32}
33
/// Serde default for `CausalNode::interventionable`: a node can be
/// intervened on unless the configuration explicitly opts it out.
fn default_true() -> bool {
    true
}

38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum NodeCategory {
41    Macro,
42    Operational,
43    Control,
44    Financial,
45    Behavioral,
46    Regulatory,
47    Outcome,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct CausalEdge {
52    pub from: String,
53    pub to: String,
54    pub transfer: TransferFunction,
55    /// Delay in months before the effect propagates.
56    #[serde(default)]
57    pub lag_months: u32,
58    /// Strength multiplier (0.0 = no effect, 1.0 = full transfer).
59    #[serde(default = "default_strength")]
60    pub strength: f64,
61    /// Human-readable description of the causal mechanism.
62    pub mechanism: Option<String>,
63}
64
/// Serde default for `CausalEdge::strength`: full transfer (a multiplier
/// of one).
fn default_strength() -> f64 {
    1.0_f64
}

69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "type", rename_all = "snake_case")]
71pub enum TransferFunction {
72    /// output = input * coefficient + intercept
73    Linear {
74        coefficient: f64,
75        #[serde(default)]
76        intercept: f64,
77    },
78    /// output = base * (1 + rate)^input
79    Exponential { base: f64, rate: f64 },
80    /// output = capacity / (1 + e^(-steepness * (input - midpoint)))
81    Logistic {
82        capacity: f64,
83        midpoint: f64,
84        steepness: f64,
85    },
86    /// output = capacity / (1 + e^(steepness * (input - midpoint)))
87    InverseLogistic {
88        capacity: f64,
89        midpoint: f64,
90        steepness: f64,
91    },
92    /// output = magnitude when input crosses threshold, else 0
93    Step { threshold: f64, magnitude: f64 },
94    /// output = magnitude when input > threshold, scaling linearly above
95    Threshold {
96        threshold: f64,
97        magnitude: f64,
98        #[serde(default = "default_saturation")]
99        saturation: f64,
100    },
101    /// output = initial * e^(-decay_rate * input)
102    Decay { initial: f64, decay_rate: f64 },
103    /// Lookup table with linear interpolation between points.
104    Piecewise { points: Vec<(f64, f64)> },
105}
106
/// Serde default for the `saturation` cap of `TransferFunction::Threshold`:
/// positive infinity, i.e. the response is never capped.
fn default_saturation() -> f64 {
    f64::INFINITY
}

111impl TransferFunction {
112    /// Compute the output value for a given input.
113    pub fn compute(&self, input: f64) -> f64 {
114        match self {
115            TransferFunction::Linear {
116                coefficient,
117                intercept,
118            } => input * coefficient + intercept,
119
120            TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
121
122            TransferFunction::Logistic {
123                capacity,
124                midpoint,
125                steepness,
126            } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
127
128            TransferFunction::InverseLogistic {
129                capacity,
130                midpoint,
131                steepness,
132            } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
133
134            TransferFunction::Step {
135                threshold,
136                magnitude,
137            } => {
138                if input > *threshold {
139                    *magnitude
140                } else {
141                    0.0
142                }
143            }
144
145            TransferFunction::Threshold {
146                threshold,
147                magnitude,
148                saturation,
149            } => {
150                if input > *threshold {
151                    (magnitude * (input - threshold) / threshold).min(*saturation)
152                } else {
153                    0.0
154                }
155            }
156
157            TransferFunction::Decay {
158                initial,
159                decay_rate,
160            } => initial * (-decay_rate * input).exp(),
161
162            TransferFunction::Piecewise { points } => {
163                if points.is_empty() {
164                    return 0.0;
165                }
166                if points.len() == 1 {
167                    return points[0].1;
168                }
169
170                // Sort points by x for interpolation
171                let mut sorted = points.clone();
172                sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
173
174                // Clamp to range
175                if input <= sorted[0].0 {
176                    return sorted[0].1;
177                }
178                if input >= sorted[sorted.len() - 1].0 {
179                    return sorted[sorted.len() - 1].1;
180                }
181
182                // Linear interpolation
183                for window in sorted.windows(2) {
184                    let (x0, y0) = window[0];
185                    let (x1, y1) = window[1];
186                    if input >= x0 && input <= x1 {
187                        let t = (input - x0) / (x1 - x0);
188                        return y0 + t * (y1 - y0);
189                    }
190                }
191
192                sorted[sorted.len() - 1].1
193            }
194        }
195    }
196}
197
198/// Errors that can occur during CausalDAG operations.
199#[derive(Debug, Error)]
200pub enum CausalDAGError {
201    #[error("cycle detected in causal DAG")]
202    CycleDetected,
203    #[error("unknown node referenced in edge: {0}")]
204    UnknownNode(String),
205    #[error("duplicate node ID: {0}")]
206    DuplicateNode(String),
207    #[error("node '{0}' is not interventionable")]
208    NonInterventionable(String),
209}
210
211impl CausalDAG {
212    /// Validate the graph is a DAG (no cycles) and compute topological order.
213    pub fn validate(&mut self) -> Result<(), CausalDAGError> {
214        let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
215
216        // Check for duplicate IDs
217        let mut seen = HashSet::new();
218        for node in &self.nodes {
219            if !seen.insert(&node.id) {
220                return Err(CausalDAGError::DuplicateNode(node.id.clone()));
221            }
222        }
223
224        // Check for unknown nodes in edges
225        for edge in &self.edges {
226            if !node_ids.contains(edge.from.as_str()) {
227                return Err(CausalDAGError::UnknownNode(edge.from.clone()));
228            }
229            if !node_ids.contains(edge.to.as_str()) {
230                return Err(CausalDAGError::UnknownNode(edge.to.clone()));
231            }
232        }
233
234        // Kahn's algorithm for topological sort
235        let mut in_degree: HashMap<&str, usize> = HashMap::new();
236        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
237
238        for node in &self.nodes {
239            in_degree.insert(&node.id, 0);
240            adjacency.insert(&node.id, Vec::new());
241        }
242
243        for edge in &self.edges {
244            *in_degree.entry(&edge.to).or_insert(0) += 1;
245            adjacency.entry(&edge.from).or_default().push(&edge.to);
246        }
247
248        let mut queue: VecDeque<&str> = VecDeque::new();
249        for (node, &degree) in &in_degree {
250            if degree == 0 {
251                queue.push_back(node);
252            }
253        }
254
255        let mut order = Vec::new();
256        while let Some(node) = queue.pop_front() {
257            order.push(node.to_string());
258            if let Some(neighbors) = adjacency.get(node) {
259                for &neighbor in neighbors {
260                    if let Some(degree) = in_degree.get_mut(neighbor) {
261                        *degree -= 1;
262                        if *degree == 0 {
263                            queue.push_back(neighbor);
264                        }
265                    }
266                }
267            }
268        }
269
270        if order.len() != self.nodes.len() {
271            return Err(CausalDAGError::CycleDetected);
272        }
273
274        self.topological_order = order;
275        Ok(())
276    }
277
278    /// Find a node by its ID.
279    pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
280        self.nodes.iter().find(|n| n.id == id)
281    }
282
283    /// Given a set of interventions (node_id → new_value), propagate
284    /// effects through the DAG in topological order.
285    pub fn propagate(
286        &self,
287        interventions: &HashMap<String, f64>,
288        month: u32,
289    ) -> HashMap<String, f64> {
290        let mut values: HashMap<String, f64> = HashMap::new();
291
292        // Initialize all nodes with baseline values
293        for node in &self.nodes {
294            values.insert(node.id.clone(), node.baseline_value);
295        }
296
297        // Override with direct interventions
298        for (node_id, value) in interventions {
299            values.insert(node_id.clone(), *value);
300        }
301
302        // Build edge lookup: to_node -> list of (from_node, edge)
303        let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
304        for edge in &self.edges {
305            incoming.entry(&edge.to).or_default().push(edge);
306        }
307
308        // Propagate in topological order
309        for node_id in &self.topological_order {
310            // Skip nodes that are directly intervened upon
311            if interventions.contains_key(node_id) {
312                continue;
313            }
314
315            if let Some(edges) = incoming.get(node_id.as_str()) {
316                let mut total_effect = 0.0;
317                let mut has_effect = false;
318
319                for edge in edges {
320                    // Check lag: only apply if enough months have passed
321                    if month < edge.lag_months {
322                        continue;
323                    }
324
325                    let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
326                    let baseline = self
327                        .find_node(&edge.from)
328                        .map(|n| n.baseline_value)
329                        .unwrap_or(0.0);
330
331                    // Compute the delta from baseline
332                    let delta = from_value - baseline;
333                    if delta.abs() < f64::EPSILON {
334                        continue;
335                    }
336
337                    // Apply transfer function to the delta
338                    let effect = edge.transfer.compute(delta) * edge.strength;
339                    total_effect += effect;
340                    has_effect = true;
341                }
342
343                if has_effect {
344                    let baseline = self
345                        .find_node(node_id)
346                        .map(|n| n.baseline_value)
347                        .unwrap_or(0.0);
348                    let mut new_value = baseline + total_effect;
349
350                    // Clamp to bounds
351                    if let Some(node) = self.find_node(node_id) {
352                        if let Some((min, max)) = node.bounds {
353                            new_value = new_value.clamp(min, max);
354                        }
355                    }
356
357                    values.insert(node_id.clone(), new_value);
358                }
359            }
360        }
361
362        values
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    fn make_node(id: &str, baseline: f64) -> CausalNode {
        CausalNode {
            id: id.to_string(),
            label: id.to_string(),
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: vec![],
        }
    }

    fn make_edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_string(),
            to: to.to_string(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    /// Shorthand for a `TransferFunction::Linear`.
    fn linear(coefficient: f64, intercept: f64) -> TransferFunction {
        TransferFunction::Linear {
            coefficient,
            intercept,
        }
    }

    /// An unvalidated DAG (empty topological order).
    fn make_dag(nodes: Vec<CausalNode>, edges: Vec<CausalEdge>) -> CausalDAG {
        CausalDAG {
            nodes,
            edges,
            topological_order: vec![],
        }
    }

    #[test]
    fn test_transfer_function_linear() {
        let tf = linear(0.5, 1.0);
        // 2.0 * 0.5 + 1.0 = 2.0
        assert!((tf.compute(2.0) - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_logistic() {
        let tf = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At midpoint, logistic returns capacity/2
        assert!((tf.compute(0.0) - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_exponential() {
        let tf = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        // base * (1 + rate)^input = 1.0 * 2.0^3.0 = 8.0
        assert!((tf.compute(3.0) - 8.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_step() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        assert!((tf.compute(3.0) - 0.0).abs() < f64::EPSILON);
        assert!((tf.compute(6.0) - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        // Below threshold
        assert!((tf.compute(1.0) - 0.0).abs() < f64::EPSILON);
        // Above threshold: 10.0 * (3.0 - 2.0) / 2.0 = 5.0
        assert!((tf.compute(3.0) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_decay() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        // At input=0: 100.0 * e^0 = 100.0
        assert!((tf.compute(0.0) - 100.0).abs() < 0.001);
        // At input=1: 100.0 * e^(-0.5) ≈ 60.65
        assert!((tf.compute(1.0) - 60.653).abs() < 0.1);
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)],
        };
        // At 0.5: interpolate between (0,0) and (1,10) → 5.0
        assert!((tf.compute(0.5) - 5.0).abs() < 0.001);
        // At 1.5: interpolate between (1,10) and (2,15) → 12.5
        assert!((tf.compute(1.5) - 12.5).abs() < 0.001);
        // Below range: clamp to first point
        assert!((tf.compute(-1.0) - 0.0).abs() < 0.001);
        // Above range: clamp to last point
        assert!((tf.compute(3.0) - 15.0).abs() < 0.001);
    }

    #[test]
    fn test_dag_validate_acyclic() {
        let mut dag = make_dag(
            vec![
                make_node("a", 1.0),
                make_node("b", 2.0),
                make_node("c", 3.0),
            ],
            vec![
                make_edge("a", "b", linear(1.0, 0.0)),
                make_edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        assert!(dag.validate().is_ok());
        assert_eq!(dag.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_empty() {
        let mut dag = make_dag(vec![], vec![]);
        assert!(dag.validate().is_ok());
        assert!(dag.topological_order.is_empty());
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut dag = make_dag(
            vec![make_node("a", 1.0), make_node("b", 2.0)],
            vec![
                make_edge("a", "b", linear(1.0, 0.0)),
                make_edge("b", "a", linear(1.0, 0.0)),
            ],
        );
        assert!(matches!(dag.validate(), Err(CausalDAGError::CycleDetected)));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut dag = make_dag(
            vec![make_node("a", 1.0)],
            vec![make_edge("a", "nonexistent", linear(1.0, 0.0))],
        );
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut dag = make_dag(vec![make_node("a", 1.0), make_node("a", 2.0)], vec![]);
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    #[test]
    fn test_find_node() {
        let dag = make_dag(vec![make_node("a", 1.0), make_node("b", 2.0)], vec![]);
        assert_eq!(dag.find_node("b").map(|n| n.baseline_value), Some(2.0));
        assert!(dag.find_node("missing").is_none());
    }

    #[test]
    fn test_dag_propagate_chain() {
        let mut dag = make_dag(
            vec![
                make_node("a", 10.0),
                make_node("b", 5.0),
                make_node("c", 0.0),
            ],
            vec![
                make_edge("a", "b", linear(0.5, 0.0)),
                make_edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        dag.validate().unwrap();

        // Intervene on A: set to 20.0 (delta = 10.0)
        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0);

        let result = dag.propagate(&interventions, 0);
        // A = 20.0 (directly set)
        assert!((result["a"] - 20.0).abs() < 0.001);
        // B baseline = 5.0, delta_a = 10.0, transfer = 10.0 * 0.5 + 0.0 = 5.0 → B = 5.0 + 5.0 = 10.0
        assert!((result["b"] - 10.0).abs() < 0.001);
        // C baseline = 0.0, delta_b = 5.0, transfer = 5.0 * 1.0 + 0.0 = 5.0 → C = 0.0 + 5.0 = 5.0
        assert!((result["c"] - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_dag_propagate_without_interventions_keeps_baselines() {
        let mut dag = make_dag(
            vec![make_node("a", 10.0), make_node("b", 5.0)],
            vec![make_edge("a", "b", linear(2.0, 1.0))],
        );
        dag.validate().unwrap();

        let result = dag.propagate(&HashMap::new(), 0);
        // No source deviates from its baseline, so nothing propagates —
        // not even the linear intercept.
        assert!((result["a"] - 10.0).abs() < f64::EPSILON);
        assert!((result["b"] - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let mut dag = make_dag(
            vec![make_node("a", 10.0), make_node("b", 5.0)],
            vec![CausalEdge {
                lag_months: 2,
                ..make_edge("a", "b", linear(1.0, 0.0))
            }],
        );
        dag.validate().unwrap();

        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0);

        // Month 1: lag is 2, so no effect yet
        let result = dag.propagate(&interventions, 1);
        assert!((result["b"] - 5.0).abs() < 0.001); // unchanged from baseline

        // Month 2: lag is met, effect propagates
        let result = dag.propagate(&interventions, 2);
        // delta_a = 10.0, transfer = 10.0, B = 5.0 + 10.0 = 15.0
        assert!((result["b"] - 15.0).abs() < 0.001);
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let mut bounded = make_node("b", 5.0);
        bounded.bounds = Some((0.0, 8.0));
        let mut dag = make_dag(
            vec![make_node("a", 10.0), bounded],
            vec![make_edge("a", "b", linear(1.0, 0.0))],
        );
        dag.validate().unwrap();

        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0); // delta = 10.0 → B would be 15.0

        let result = dag.propagate(&interventions, 0);
        // B should be clamped to max bound of 8.0
        assert!((result["b"] - 8.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_serde() {
        let tf = linear(0.5, 1.0);
        let json = serde_json::to_string(&tf).unwrap();
        let deserialized: TransferFunction = serde_json::from_str(&json).unwrap();
        assert!((deserialized.compute(2.0) - 2.0).abs() < f64::EPSILON);
    }

    // ====================================================================
    // Transfer function edge cases
    // ====================================================================

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        let tf = linear(0.0, 5.0);
        // Any input → just the intercept
        assert!((tf.compute(0.0) - 5.0).abs() < f64::EPSILON);
        assert!((tf.compute(100.0) - 5.0).abs() < f64::EPSILON);
        assert!((tf.compute(-100.0) - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        let tf = linear(-2.0, 10.0);
        assert!((tf.compute(3.0) - 4.0).abs() < f64::EPSILON); // -6 + 10 = 4
        assert!((tf.compute(5.0) - 0.0).abs() < f64::EPSILON); // -10 + 10 = 0
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let tf = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        // (1+0.5)^0 = 1, so result = 5.0
        assert!((tf.compute(0.0) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        let tf = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        // (1 + (-0.5))^2 = 0.5^2 = 0.25, result = 25.0
        assert!((tf.compute(2.0) - 25.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let tf = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        // Far below midpoint → near 0
        assert!(tf.compute(-10.0) < 0.01);
        // Far above midpoint → near capacity
        assert!((tf.compute(20.0) - 10.0).abs() < 0.01);
        // At midpoint → capacity/2
        assert!((tf.compute(5.0) - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        // High steepness → sharper transition
        let steep = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 10.0,
        };
        let gentle = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 0.5,
        };
        // Both should be ~0.5 at midpoint
        assert!((steep.compute(0.0) - 0.5).abs() < 0.01);
        assert!((gentle.compute(0.0) - 0.5).abs() < 0.01);
        // Steep should be closer to 1.0 at input=1.0
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let tf = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At midpoint → capacity/2
        assert!((tf.compute(0.0) - 0.5).abs() < 0.001);
        // Inverse logistic decreases: far above midpoint → near 0
        assert!(tf.compute(10.0) < 0.01);
        // Far below midpoint → near capacity
        assert!((tf.compute(-10.0) - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // Logistic + InverseLogistic should sum to capacity at any point
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // At exactly the threshold, should be 0 (not strictly greater)
        assert!((tf.compute(5.0) - 0.0).abs() < f64::EPSILON);
        // Just above threshold
        assert!((tf.compute(5.001) - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let tf = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        assert!((tf.compute(-1.0) - 0.0).abs() < f64::EPSILON);
        assert!((tf.compute(1.0) - (-5.0)).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        // Below threshold: 0
        assert!((tf.compute(1.0) - 0.0).abs() < f64::EPSILON);
        // Just above threshold: 10.0 * (2.5 - 2.0) / 2.0 = 2.5
        assert!((tf.compute(2.5) - 2.5).abs() < 0.001);
        // Way above threshold without saturation: 10.0 * (100.0 - 2.0) / 2.0 = 490,
        // but capped at saturation = 8.0
        assert!((tf.compute(100.0) - 8.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        // No saturation cap: grows linearly
        // 5.0 * (100.0 - 1.0) / 1.0 = 495.0
        assert!((tf.compute(100.0) - 495.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        // Large input → approaches 0
        assert!(tf.compute(10.0) < 0.01);
        assert!(tf.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let tf = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        // No decay → constant
        assert!((tf.compute(0.0) - 50.0).abs() < f64::EPSILON);
        assert!((tf.compute(100.0) - 50.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let tf = TransferFunction::Piecewise {
            points: vec![(5.0, 42.0)],
        };
        // Single point → always returns that value
        assert!((tf.compute(0.0) - 42.0).abs() < f64::EPSILON);
        assert!((tf.compute(100.0) - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let tf = TransferFunction::Piecewise { points: vec![] };
        assert!((tf.compute(5.0) - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let tf = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)],
        };
        // At exact breakpoints
        assert!((tf.compute(0.0) - 0.0).abs() < 0.001);
        assert!((tf.compute(1.0) - 10.0).abs() < 0.001);
        assert!((tf.compute(2.0) - 15.0).abs() < 0.001);
        assert!((tf.compute(3.0) - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        // Points given out of order — should still interpolate correctly
        let tf = TransferFunction::Piecewise {
            points: vec![(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)],
        };
        assert!((tf.compute(0.5) - 5.0).abs() < 0.001);
        assert!((tf.compute(1.5) - 15.0).abs() < 0.001);
    }
}