1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use thiserror::Error;
4
/// A causal model expressed as a set of nodes connected by directed edges.
///
/// The struct is (de)serialized with serde. `topological_order` is marked
/// `#[serde(skip)]`, so it is not part of the serialized form; it is filled in
/// by [`CausalDAG::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalDAG {
    /// All nodes of the graph.
    pub nodes: Vec<CausalNode>,
    /// Directed edges; each edge's `from`/`to` name a node by its `id`.
    pub edges: Vec<CausalEdge>,
    /// Node IDs in topological order as computed by `validate`.
    /// `propagate` visits nodes in exactly this order.
    #[serde(skip)]
    pub topological_order: Vec<String>,
}
15
/// A single variable in the causal graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalNode {
    /// Identifier that edges use to refer to this node; `validate` rejects duplicates.
    pub id: String,
    /// Display label for the node.
    pub label: String,
    /// Category the node is classified under.
    pub category: NodeCategory,
    /// Value the node has when nothing perturbs it; `propagate` starts from it
    /// and measures upstream changes relative to it.
    pub baseline_value: f64,
    /// Optional `(min, max)` pair; `propagate` clamps a recomputed value into it.
    pub bounds: Option<(f64, f64)>,
    /// Defaults to `true` when absent from the serialized input.
    /// NOTE(review): not consulted by the code in this file — presumably enforced by callers.
    #[serde(default = "default_true")]
    pub interventionable: bool,
    /// Defaults to an empty list when absent from the serialized input.
    /// NOTE(review): not read by the code in this file — meaning to be confirmed with callers.
    #[serde(default)]
    pub config_bindings: Vec<String>,
}
33
/// Serde default for `CausalNode::interventionable`: always `true`.
fn default_true() -> bool {
    true
}
37
/// Classification attached to a [`CausalNode`].
///
/// Variants are (de)serialized in `snake_case` (e.g. `Macro` as `"macro"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum NodeCategory {
    Macro,
    Operational,
    Control,
    Financial,
    Behavioral,
    Regulatory,
    Outcome,
}
49
/// A directed causal link from one node to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalEdge {
    /// ID of the source (cause) node.
    pub from: String,
    /// ID of the target (effect) node.
    pub to: String,
    /// Function mapping the source node's deviation from its baseline to an effect.
    pub transfer: TransferFunction,
    /// Delay before the edge takes effect; `propagate` ignores the edge while
    /// `month < lag_months`. Defaults to `0` when absent.
    #[serde(default)]
    pub lag_months: u32,
    /// Multiplier applied to the transfer function's output. Defaults to `1.0`.
    #[serde(default = "default_strength")]
    pub strength: f64,
    /// Optional free-form text.
    /// NOTE(review): not read by the code in this file — presumably a description of the link.
    pub mechanism: Option<String>,
}
64
/// Serde default for `CausalEdge::strength`: `1.0`, i.e. the transfer output is used unscaled.
fn default_strength() -> f64 {
    1.0
}
68
/// Shape of the relationship carried by a [`CausalEdge`].
///
/// Serialized as an internally tagged enum: a `"type"` field holds the variant
/// name in `snake_case` (e.g. `"inverse_logistic"`), alongside the variant's fields.
/// The formulas below are the ones evaluated by [`TransferFunction::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TransferFunction {
    /// `input * coefficient + intercept`; `intercept` defaults to `0.0`.
    Linear {
        coefficient: f64,
        #[serde(default)]
        intercept: f64,
    },
    /// `base * (1 + rate)^input`.
    Exponential { base: f64, rate: f64 },
    /// `capacity / (1 + e^(-steepness * (input - midpoint)))`.
    Logistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// `capacity / (1 + e^(steepness * (input - midpoint)))`.
    InverseLogistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// `magnitude` when `input > threshold` (strictly), otherwise `0.0`.
    Step { threshold: f64, magnitude: f64 },
    /// `0.0` up to and including `threshold`; above it,
    /// `magnitude * (input - threshold) / threshold`, capped at `saturation`.
    /// `saturation` defaults to positive infinity (no cap).
    Threshold {
        threshold: f64,
        magnitude: f64,
        #[serde(default = "default_saturation")]
        saturation: f64,
    },
    /// `initial * e^(-decay_rate * input)`.
    Decay { initial: f64, decay_rate: f64 },
    /// Linear interpolation through `(x, y)` points; inputs outside the x-range
    /// take the y of the nearest end point.
    Piecewise { points: Vec<(f64, f64)> },
}
106
/// Serde default for `TransferFunction::Threshold::saturation`: positive infinity,
/// so the `min` against it in `compute` never lowers the result.
fn default_saturation() -> f64 {
    f64::INFINITY
}
110
111impl TransferFunction {
112 pub fn compute(&self, input: f64) -> f64 {
114 match self {
115 TransferFunction::Linear {
116 coefficient,
117 intercept,
118 } => input * coefficient + intercept,
119
120 TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
121
122 TransferFunction::Logistic {
123 capacity,
124 midpoint,
125 steepness,
126 } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
127
128 TransferFunction::InverseLogistic {
129 capacity,
130 midpoint,
131 steepness,
132 } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
133
134 TransferFunction::Step {
135 threshold,
136 magnitude,
137 } => {
138 if input > *threshold {
139 *magnitude
140 } else {
141 0.0
142 }
143 }
144
145 TransferFunction::Threshold {
146 threshold,
147 magnitude,
148 saturation,
149 } => {
150 if input > *threshold {
151 (magnitude * (input - threshold) / threshold).min(*saturation)
152 } else {
153 0.0
154 }
155 }
156
157 TransferFunction::Decay {
158 initial,
159 decay_rate,
160 } => initial * (-decay_rate * input).exp(),
161
162 TransferFunction::Piecewise { points } => {
163 if points.is_empty() {
164 return 0.0;
165 }
166 if points.len() == 1 {
167 return points[0].1;
168 }
169
170 let mut sorted = points.clone();
172 sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
173
174 if input <= sorted[0].0 {
176 return sorted[0].1;
177 }
178 if input >= sorted[sorted.len() - 1].0 {
179 return sorted[sorted.len() - 1].1;
180 }
181
182 for window in sorted.windows(2) {
184 let (x0, y0) = window[0];
185 let (x1, y1) = window[1];
186 if input >= x0 && input <= x1 {
187 let t = (input - x0) / (x1 - x0);
188 return y0 + t * (y1 - y0);
189 }
190 }
191
192 sorted[sorted.len() - 1].1
193 }
194 }
195 }
196}
197
/// Errors reported for a [`CausalDAG`].
#[derive(Debug, Error)]
pub enum CausalDAGError {
    /// The edges contain a cycle, so no topological order exists.
    #[error("cycle detected in causal DAG")]
    CycleDetected,
    /// An edge endpoint names a node ID that is not in `nodes`; carries that ID.
    #[error("unknown node referenced in edge: {0}")]
    UnknownNode(String),
    /// Two nodes share the same ID; carries that ID.
    #[error("duplicate node ID: {0}")]
    DuplicateNode(String),
    /// Carries the ID of a node that may not be intervened on.
    /// NOTE(review): not produced by the code in this file — presumably raised by callers.
    #[error("node '{0}' is not interventionable")]
    NonInterventionable(String),
}
210
211impl CausalDAG {
212 pub fn validate(&mut self) -> Result<(), CausalDAGError> {
214 let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
215
216 let mut seen = HashSet::new();
218 for node in &self.nodes {
219 if !seen.insert(&node.id) {
220 return Err(CausalDAGError::DuplicateNode(node.id.clone()));
221 }
222 }
223
224 for edge in &self.edges {
226 if !node_ids.contains(edge.from.as_str()) {
227 return Err(CausalDAGError::UnknownNode(edge.from.clone()));
228 }
229 if !node_ids.contains(edge.to.as_str()) {
230 return Err(CausalDAGError::UnknownNode(edge.to.clone()));
231 }
232 }
233
234 let mut in_degree: HashMap<&str, usize> = HashMap::new();
236 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
237
238 for node in &self.nodes {
239 in_degree.insert(&node.id, 0);
240 adjacency.insert(&node.id, Vec::new());
241 }
242
243 for edge in &self.edges {
244 *in_degree.entry(&edge.to).or_insert(0) += 1;
245 adjacency.entry(&edge.from).or_default().push(&edge.to);
246 }
247
248 let mut queue: VecDeque<&str> = VecDeque::new();
249 for (node, °ree) in &in_degree {
250 if degree == 0 {
251 queue.push_back(node);
252 }
253 }
254
255 let mut order = Vec::new();
256 while let Some(node) = queue.pop_front() {
257 order.push(node.to_string());
258 if let Some(neighbors) = adjacency.get(node) {
259 for &neighbor in neighbors {
260 if let Some(degree) = in_degree.get_mut(neighbor) {
261 *degree -= 1;
262 if *degree == 0 {
263 queue.push_back(neighbor);
264 }
265 }
266 }
267 }
268 }
269
270 if order.len() != self.nodes.len() {
271 return Err(CausalDAGError::CycleDetected);
272 }
273
274 self.topological_order = order;
275 Ok(())
276 }
277
278 pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
280 self.nodes.iter().find(|n| n.id == id)
281 }
282
283 pub fn propagate(
286 &self,
287 interventions: &HashMap<String, f64>,
288 month: u32,
289 ) -> HashMap<String, f64> {
290 let mut values: HashMap<String, f64> = HashMap::new();
291
292 for node in &self.nodes {
294 values.insert(node.id.clone(), node.baseline_value);
295 }
296
297 for (node_id, value) in interventions {
299 values.insert(node_id.clone(), *value);
300 }
301
302 let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
304 for edge in &self.edges {
305 incoming.entry(&edge.to).or_default().push(edge);
306 }
307
308 for node_id in &self.topological_order {
310 if interventions.contains_key(node_id) {
312 continue;
313 }
314
315 if let Some(edges) = incoming.get(node_id.as_str()) {
316 let mut total_effect = 0.0;
317 let mut has_effect = false;
318
319 for edge in edges {
320 if month < edge.lag_months {
322 continue;
323 }
324
325 let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
326 let baseline = self
327 .find_node(&edge.from)
328 .map(|n| n.baseline_value)
329 .unwrap_or(0.0);
330
331 let delta = from_value - baseline;
333 if delta.abs() < f64::EPSILON {
334 continue;
335 }
336
337 let effect = edge.transfer.compute(delta) * edge.strength;
339 total_effect += effect;
340 has_effect = true;
341 }
342
343 if has_effect {
344 let baseline = self
345 .find_node(node_id)
346 .map(|n| n.baseline_value)
347 .unwrap_or(0.0);
348 let mut new_value = baseline + total_effect;
349
350 if let Some(node) = self.find_node(node_id) {
352 if let Some((min, max)) = node.bounds {
353 new_value = new_value.clamp(min, max);
354 }
355 }
356
357 values.insert(node_id.clone(), new_value);
358 }
359 }
360 }
361
362 values
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Operational node with no bounds, labelled with its own ID.
    fn node(id: &str, baseline: f64) -> CausalNode {
        CausalNode {
            id: id.to_string(),
            label: id.to_string(),
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: vec![],
        }
    }

    /// Edge with no lag and unit strength.
    fn edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_string(),
            to: to.to_string(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    fn linear(coefficient: f64, intercept: f64) -> TransferFunction {
        TransferFunction::Linear {
            coefficient,
            intercept,
        }
    }

    fn piecewise(points: &[(f64, f64)]) -> TransferFunction {
        TransferFunction::Piecewise {
            points: points.to_vec(),
        }
    }

    /// Graph that has not been validated yet.
    fn graph(nodes: Vec<CausalNode>, edges: Vec<CausalEdge>) -> CausalDAG {
        CausalDAG {
            nodes,
            edges,
            topological_order: vec![],
        }
    }

    /// Intervention map that pins node "a" to `value`.
    fn pin_a(value: f64) -> HashMap<String, f64> {
        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), value);
        interventions
    }

    #[test]
    fn test_transfer_function_linear() {
        assert_close(linear(0.5, 1.0).compute(2.0), 2.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_logistic() {
        let tf = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        assert_close(tf.compute(0.0), 0.5, 0.001);
    }

    #[test]
    fn test_transfer_function_exponential() {
        let tf = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        assert_close(tf.compute(3.0), 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_step() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        for (input, expected) in [(3.0, 0.0), (6.0, 10.0)] {
            assert_close(tf.compute(input), expected, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_threshold() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        assert_close(tf.compute(1.0), 0.0, f64::EPSILON);
        assert_close(tf.compute(3.0), 5.0, 0.001);
    }

    #[test]
    fn test_transfer_function_decay() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        assert_close(tf.compute(0.0), 100.0, 0.001);
        assert_close(tf.compute(1.0), 60.653, 0.1);
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let tf = piecewise(&[(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)]);
        for (input, expected) in [(0.5, 5.0), (1.5, 12.5), (-1.0, 0.0), (3.0, 15.0)] {
            assert_close(tf.compute(input), expected, 0.001);
        }
    }

    #[test]
    fn test_dag_validate_acyclic() {
        let mut dag = graph(
            vec![node("a", 1.0), node("b", 2.0), node("c", 3.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        assert!(dag.validate().is_ok());
        assert_eq!(dag.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut dag = graph(
            vec![node("a", 1.0), node("b", 2.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "a", linear(1.0, 0.0)),
            ],
        );
        assert!(matches!(dag.validate(), Err(CausalDAGError::CycleDetected)));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut dag = graph(
            vec![node("a", 1.0)],
            vec![edge("a", "nonexistent", linear(1.0, 0.0))],
        );
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut dag = graph(vec![node("a", 1.0), node("a", 2.0)], vec![]);
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    #[test]
    fn test_dag_propagate_chain() {
        let mut dag = graph(
            vec![node("a", 10.0), node("b", 5.0), node("c", 0.0)],
            vec![
                edge("a", "b", linear(0.5, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        dag.validate().unwrap();

        let result = dag.propagate(&pin_a(20.0), 0);
        for (id, expected) in [("a", 20.0), ("b", 10.0), ("c", 5.0)] {
            assert_close(result[id], expected, 0.001);
        }
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let lagged = CausalEdge {
            lag_months: 2,
            ..edge("a", "b", linear(1.0, 0.0))
        };
        let mut dag = graph(vec![node("a", 10.0), node("b", 5.0)], vec![lagged]);
        dag.validate().unwrap();

        let interventions = pin_a(20.0);
        // Before the lag has elapsed the edge is inactive; from month 2 it applies.
        let early = dag.propagate(&interventions, 1);
        assert_close(early["b"], 5.0, 0.001);
        let late = dag.propagate(&interventions, 2);
        assert_close(late["b"], 15.0, 0.001);
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let bounded = CausalNode {
            bounds: Some((0.0, 8.0)),
            ..node("b", 5.0)
        };
        let mut dag = graph(
            vec![node("a", 10.0), bounded],
            vec![edge("a", "b", linear(1.0, 0.0))],
        );
        dag.validate().unwrap();

        let result = dag.propagate(&pin_a(20.0), 0);
        assert_close(result["b"], 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_serde() {
        let json = serde_json::to_string(&linear(0.5, 1.0)).unwrap();
        let deserialized: TransferFunction = serde_json::from_str(&json).unwrap();
        assert_close(deserialized.compute(2.0), 2.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        let tf = linear(0.0, 5.0);
        for input in [0.0, 100.0, -100.0] {
            assert_close(tf.compute(input), 5.0, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        let tf = linear(-2.0, 10.0);
        for (input, expected) in [(3.0, 4.0), (5.0, 0.0)] {
            assert_close(tf.compute(input), expected, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let tf = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        assert_close(tf.compute(0.0), 5.0, 0.001);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        let tf = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        assert_close(tf.compute(2.0), 25.0, 0.001);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let tf = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        assert!(tf.compute(-10.0) < 0.01);
        assert_close(tf.compute(20.0), 10.0, 0.01);
        assert_close(tf.compute(5.0), 5.0, 0.01);
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        let with_steepness = |steepness: f64| TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness,
        };
        let steep = with_steepness(10.0);
        let gentle = with_steepness(0.5);
        assert_close(steep.compute(0.0), 0.5, 0.01);
        assert_close(gentle.compute(0.0), 0.5, 0.01);
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let tf = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        assert_close(tf.compute(0.0), 0.5, 0.001);
        assert!(tf.compute(10.0) < 0.01);
        assert_close(tf.compute(-10.0), 1.0, 0.01);
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // With matching parameters the two curves sum to the capacity everywhere.
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let tf = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // The comparison is strict: exactly at the threshold the step has not fired.
        for (input, expected) in [(5.0, 0.0), (5.001, 10.0)] {
            assert_close(tf.compute(input), expected, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let tf = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        for (input, expected) in [(-1.0, 0.0), (1.0, -5.0)] {
            assert_close(tf.compute(input), expected, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        assert_close(tf.compute(1.0), 0.0, f64::EPSILON);
        assert_close(tf.compute(2.5), 2.5, 0.001);
        assert_close(tf.compute(100.0), 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let tf = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        assert_close(tf.compute(100.0), 495.0, 0.001);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let tf = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        assert!(tf.compute(10.0) < 0.01);
        assert!(tf.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let tf = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        for input in [0.0, 100.0] {
            assert_close(tf.compute(input), 50.0, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let tf = piecewise(&[(5.0, 42.0)]);
        for input in [0.0, 100.0] {
            assert_close(tf.compute(input), 42.0, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let tf = TransferFunction::Piecewise { points: vec![] };
        assert_close(tf.compute(5.0), 0.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let knots = [(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)];
        let tf = piecewise(&knots);
        for (x, y) in knots {
            assert_close(tf.compute(x), y, 0.001);
        }
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        let tf = piecewise(&[(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)]);
        for (input, expected) in [(0.5, 5.0), (1.5, 15.0)] {
            assert_close(tf.compute(input), expected, 0.001);
        }
    }
}