1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use thiserror::Error;
4
/// A causal model: variables (`nodes`) connected by directed links (`edges`)
/// that carry transfer functions.
///
/// Call [`CausalDAG::validate`] before [`CausalDAG::propagate`]: `validate`
/// checks structural integrity and fills `topological_order`, which
/// `propagate` walks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalDAG {
    /// All variables in the model. Ids are expected to be unique; `validate`
    /// rejects duplicates.
    pub nodes: Vec<CausalNode>,
    /// Directed causal links; endpoints refer to nodes by id.
    pub edges: Vec<CausalEdge>,
    /// Node ids in topological order, as computed by `validate`.
    ///
    /// Not serialized: `#[serde(skip)]` leaves it empty (its `Default`) after
    /// deserialization until `validate` is called again.
    #[serde(skip)]
    pub topological_order: Vec<String>,
}
15
/// A single variable in the causal model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalNode {
    /// Identifier that edges and interventions use to refer to this node.
    pub id: String,
    /// Display label; not read by `validate` or `propagate`.
    pub label: String,
    /// Broad classification of the variable.
    pub category: NodeCategory,
    /// Value the node holds when nothing upstream deviates from its own
    /// baseline; `propagate` adds incoming edge effects on top of it.
    pub baseline_value: f64,
    /// Optional `(min, max)` range. Values that `propagate` computes from
    /// incoming edge effects are clamped into it; intervened values are not.
    pub bounds: Option<(f64, f64)>,
    /// Whether the node may be intervened on; defaults to `true` when absent
    /// from serialized input. NOTE(review): `propagate` does not check this
    /// flag itself — presumably enforced by callers; confirm.
    #[serde(default = "default_true")]
    pub interventionable: bool,
    /// Extra binding strings; default to empty when absent from serialized
    /// input. NOTE(review): their meaning is defined outside this file — confirm.
    #[serde(default)]
    pub config_bindings: Vec<String>,
}
33
/// Serde default for `CausalNode::interventionable`: a node without an
/// explicit flag in the serialized data is treated as interventionable.
const fn default_true() -> bool {
    true
}
37
/// Broad classification of a [`CausalNode`].
///
/// Serialized in `snake_case` (e.g. `"operational"`, `"outcome"`). The causal
/// logic in `validate` and `propagate` does not inspect the category.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum NodeCategory {
    Macro,
    Operational,
    Control,
    Financial,
    Behavioral,
    Regulatory,
    Outcome,
}
49
/// A directed causal link from node `from` to node `to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalEdge {
    /// Id of the source (cause) node.
    pub from: String,
    /// Id of the target (effect) node.
    pub to: String,
    /// Maps the source node's deviation from its baseline to an effect on the
    /// target.
    pub transfer: TransferFunction,
    /// Delay before the edge takes effect: `propagate` ignores the edge while
    /// `month < lag_months`. Defaults to `0` (immediate).
    #[serde(default)]
    pub lag_months: u32,
    /// Multiplier applied to the transfer function's output. Defaults to `1.0`.
    #[serde(default = "default_strength")]
    pub strength: f64,
    /// Optional note on the edge; presumably a free-text description of the
    /// causal mechanism (not read by the logic in this file).
    pub mechanism: Option<String>,
}
64
/// Serde default for `CausalEdge::strength`: an edge without an explicit
/// strength passes its transfer output through unscaled.
const fn default_strength() -> f64 {
    1.0
}
68
/// Shape of the response carried by a [`CausalEdge`].
///
/// Serialized as an internally tagged object whose `"type"` field holds the
/// snake_case variant name, e.g.
/// `{"type": "linear", "coefficient": 0.5, "intercept": 0.0}`.
/// In the formulas below `x` is the value passed to `compute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TransferFunction {
    /// `coefficient * x + intercept`; `intercept` defaults to `0.0`.
    Linear {
        coefficient: f64,
        #[serde(default)]
        intercept: f64,
    },
    /// `base * (1 + rate)^x`.
    Exponential { base: f64, rate: f64 },
    /// Rising sigmoid `capacity / (1 + e^(-steepness * (x - midpoint)))`;
    /// equals `capacity / 2` at `x == midpoint`.
    Logistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// Falling sigmoid `capacity / (1 + e^(steepness * (x - midpoint)))`;
    /// with identical parameters it sums with `Logistic` to `capacity`.
    InverseLogistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// `magnitude` when `x > threshold` (strictly), otherwise `0.0`.
    Step { threshold: f64, magnitude: f64 },
    /// `0.0` when `x <= threshold`. Above it, for a positive `threshold`, the
    /// result is `magnitude * (x - threshold) / threshold` (the excess relative
    /// to the threshold). That value is capped from above at `saturation`,
    /// which defaults to `f64::INFINITY` (no cap).
    Threshold {
        threshold: f64,
        magnitude: f64,
        #[serde(default = "default_saturation")]
        saturation: f64,
    },
    /// `initial * e^(-decay_rate * x)`.
    Decay { initial: f64, decay_rate: f64 },
    /// Linear interpolation through `(x, y)` points, which are ordered by `x`
    /// before use and need not be supplied sorted. The first/last `y` is held
    /// outside the covered range. No points yields `0.0`; a single point yields
    /// its `y` everywhere.
    Piecewise { points: Vec<(f64, f64)> },
}
106
/// Serde default for `TransferFunction::Threshold::saturation`: with no
/// explicit cap, the threshold response is unbounded above.
const fn default_saturation() -> f64 {
    f64::INFINITY
}
110
111impl TransferFunction {
112 pub fn compute(&self, input: f64) -> f64 {
114 match self {
115 TransferFunction::Linear {
116 coefficient,
117 intercept,
118 } => input * coefficient + intercept,
119
120 TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
121
122 TransferFunction::Logistic {
123 capacity,
124 midpoint,
125 steepness,
126 } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
127
128 TransferFunction::InverseLogistic {
129 capacity,
130 midpoint,
131 steepness,
132 } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
133
134 TransferFunction::Step {
135 threshold,
136 magnitude,
137 } => {
138 if input > *threshold {
139 *magnitude
140 } else {
141 0.0
142 }
143 }
144
145 TransferFunction::Threshold {
146 threshold,
147 magnitude,
148 saturation,
149 } => {
150 if input > *threshold {
151 (magnitude * (input - threshold) / threshold).min(*saturation)
152 } else {
153 0.0
154 }
155 }
156
157 TransferFunction::Decay {
158 initial,
159 decay_rate,
160 } => initial * (-decay_rate * input).exp(),
161
162 TransferFunction::Piecewise { points } => {
163 if points.is_empty() {
164 return 0.0;
165 }
166 if points.len() == 1 {
167 return points[0].1;
168 }
169
170 let mut sorted = points.clone();
172 sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
173
174 if input <= sorted[0].0 {
176 return sorted[0].1;
177 }
178 if input >= sorted[sorted.len() - 1].0 {
179 return sorted[sorted.len() - 1].1;
180 }
181
182 for window in sorted.windows(2) {
184 let (x0, y0) = window[0];
185 let (x1, y1) = window[1];
186 if input >= x0 && input <= x1 {
187 let t = (input - x0) / (x1 - x0);
188 return y0 + t * (y1 - y0);
189 }
190 }
191
192 sorted[sorted.len() - 1].1
193 }
194 }
195 }
196}
197
/// Errors produced while working with a [`CausalDAG`].
///
/// `validate` returns the first three variants. `NonInterventionable` is not
/// produced by `validate` or `propagate`. NOTE(review): presumably raised by
/// intervention handling elsewhere — confirm.
#[derive(Debug, Error)]
pub enum CausalDAGError {
    /// The edges contain at least one directed cycle.
    #[error("cycle detected in causal DAG")]
    CycleDetected,
    /// An edge's `from` or `to` names an id that is not a declared node.
    #[error("unknown node referenced in edge: {0}")]
    UnknownNode(String),
    /// Two nodes share the same id.
    #[error("duplicate node ID: {0}")]
    DuplicateNode(String),
    /// The named node has `interventionable == false` (presumably raised when
    /// an intervention targets it).
    #[error("node '{0}' is not interventionable")]
    NonInterventionable(String),
}
210
211impl CausalDAG {
212 pub fn validate(&mut self) -> Result<(), CausalDAGError> {
214 let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
215
216 let mut seen = HashSet::new();
218 for node in &self.nodes {
219 if !seen.insert(&node.id) {
220 return Err(CausalDAGError::DuplicateNode(node.id.clone()));
221 }
222 }
223
224 for edge in &self.edges {
226 if !node_ids.contains(edge.from.as_str()) {
227 return Err(CausalDAGError::UnknownNode(edge.from.clone()));
228 }
229 if !node_ids.contains(edge.to.as_str()) {
230 return Err(CausalDAGError::UnknownNode(edge.to.clone()));
231 }
232 }
233
234 let mut in_degree: HashMap<&str, usize> = HashMap::new();
236 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
237
238 for node in &self.nodes {
239 in_degree.insert(&node.id, 0);
240 adjacency.insert(&node.id, Vec::new());
241 }
242
243 for edge in &self.edges {
244 *in_degree.entry(&edge.to).or_insert(0) += 1;
245 adjacency.entry(&edge.from).or_default().push(&edge.to);
246 }
247
248 let mut queue: VecDeque<&str> = VecDeque::new();
249 for (node, °ree) in &in_degree {
250 if degree == 0 {
251 queue.push_back(node);
252 }
253 }
254
255 let mut order = Vec::new();
256 while let Some(node) = queue.pop_front() {
257 order.push(node.to_string());
258 if let Some(neighbors) = adjacency.get(node) {
259 for &neighbor in neighbors {
260 if let Some(degree) = in_degree.get_mut(neighbor) {
261 *degree -= 1;
262 if *degree == 0 {
263 queue.push_back(neighbor);
264 }
265 }
266 }
267 }
268 }
269
270 if order.len() != self.nodes.len() {
271 return Err(CausalDAGError::CycleDetected);
272 }
273
274 self.topological_order = order;
275 Ok(())
276 }
277
278 pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
280 self.nodes.iter().find(|n| n.id == id)
281 }
282
283 pub fn propagate(
286 &self,
287 interventions: &HashMap<String, f64>,
288 month: u32,
289 ) -> HashMap<String, f64> {
290 let mut values: HashMap<String, f64> = HashMap::new();
291
292 for node in &self.nodes {
294 values.insert(node.id.clone(), node.baseline_value);
295 }
296
297 for (node_id, value) in interventions {
299 values.insert(node_id.clone(), *value);
300 }
301
302 let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
304 for edge in &self.edges {
305 incoming.entry(&edge.to).or_default().push(edge);
306 }
307
308 for node_id in &self.topological_order {
310 if interventions.contains_key(node_id) {
312 continue;
313 }
314
315 if let Some(edges) = incoming.get(node_id.as_str()) {
316 let mut total_effect = 0.0;
317 let mut has_effect = false;
318
319 for edge in edges {
320 if month < edge.lag_months {
322 continue;
323 }
324
325 let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
326 let baseline = self
327 .find_node(&edge.from)
328 .map(|n| n.baseline_value)
329 .unwrap_or(0.0);
330
331 let delta = from_value - baseline;
333 if delta.abs() < f64::EPSILON {
334 continue;
335 }
336
337 let effect = edge.transfer.compute(delta) * edge.strength;
339 total_effect += effect;
340 has_effect = true;
341 }
342
343 if has_effect {
344 let baseline = self
345 .find_node(node_id)
346 .map(|n| n.baseline_value)
347 .unwrap_or(0.0);
348 let mut new_value = baseline + total_effect;
349
350 if let Some(node) = self.find_node(node_id) {
352 if let Some((min, max)) = node.bounds {
353 new_value = new_value.clamp(min, max);
354 }
355 }
356
357 values.insert(node_id.clone(), new_value);
358 }
359 }
360 }
361
362 values
363 }
364}
365
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an interventionable `Operational` node whose label mirrors its id.
    fn make_node(id: &str, baseline: f64) -> CausalNode {
        let name = id.to_owned();
        CausalNode {
            label: name.clone(),
            id: name,
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: Vec::new(),
        }
    }

    /// Builds an immediate (zero-lag), unit-strength edge with no mechanism note.
    fn make_edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_owned(),
            to: to.to_owned(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    /// Shorthand for a `TransferFunction::Linear`.
    fn linear(coefficient: f64, intercept: f64) -> TransferFunction {
        TransferFunction::Linear {
            coefficient,
            intercept,
        }
    }

    /// Wraps nodes and edges in a DAG whose topological order is not yet computed.
    fn dag_of(nodes: Vec<CausalNode>, edges: Vec<CausalEdge>) -> CausalDAG {
        CausalDAG {
            nodes,
            edges,
            topological_order: Vec::new(),
        }
    }

    /// Interventions that pin node `a` to `20.0`.
    fn pin_a_to_20() -> HashMap<String, f64> {
        let mut interventions = HashMap::new();
        interventions.insert("a".to_string(), 20.0);
        interventions
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} within {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    // ---- transfer functions -------------------------------------------------

    #[test]
    fn test_transfer_function_linear() {
        // 0.5 * 2 + 1 = 2
        assert_close(linear(0.5, 1.0).compute(2.0), 2.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_logistic() {
        // Half of the capacity at the midpoint.
        let sigmoid = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        assert_close(sigmoid.compute(0.0), 0.5, 0.001);
    }

    #[test]
    fn test_transfer_function_exponential() {
        // 1 * (1 + 1)^3 = 8
        let growth = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        assert_close(growth.compute(3.0), 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_step() {
        let step = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        assert_close(step.compute(3.0), 0.0, f64::EPSILON);
        assert_close(step.compute(6.0), 10.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold() {
        let gate = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        assert_close(gate.compute(1.0), 0.0, f64::EPSILON);
        // 10 * (3 - 2) / 2 = 5
        assert_close(gate.compute(3.0), 5.0, 0.001);
    }

    #[test]
    fn test_transfer_function_decay() {
        let decay = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        assert_close(decay.compute(0.0), 100.0, 0.001);
        // 100 * e^-0.5 ≈ 60.653
        assert_close(decay.compute(1.0), 60.653, 0.1);
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let curve = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)],
        };
        // Two interior interpolations, then both clamped ends.
        for (input, expected) in [(0.5, 5.0), (1.5, 12.5), (-1.0, 0.0), (3.0, 15.0)] {
            assert_close(curve.compute(input), expected, 0.001);
        }
    }

    // ---- DAG validation -----------------------------------------------------

    #[test]
    fn test_dag_validate_acyclic() {
        let mut dag = dag_of(
            vec![
                make_node("a", 1.0),
                make_node("b", 2.0),
                make_node("c", 3.0),
            ],
            vec![
                make_edge("a", "b", linear(1.0, 0.0)),
                make_edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        assert!(dag.validate().is_ok());
        assert_eq!(dag.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut dag = dag_of(
            vec![make_node("a", 1.0), make_node("b", 2.0)],
            vec![
                make_edge("a", "b", linear(1.0, 0.0)),
                make_edge("b", "a", linear(1.0, 0.0)),
            ],
        );
        assert!(matches!(dag.validate(), Err(CausalDAGError::CycleDetected)));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut dag = dag_of(
            vec![make_node("a", 1.0)],
            vec![make_edge("a", "nonexistent", linear(1.0, 0.0))],
        );
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut dag = dag_of(vec![make_node("a", 1.0), make_node("a", 2.0)], Vec::new());
        assert!(matches!(
            dag.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    // ---- DAG propagation ----------------------------------------------------

    #[test]
    fn test_dag_propagate_chain() {
        let mut dag = dag_of(
            vec![
                make_node("a", 10.0),
                make_node("b", 5.0),
                make_node("c", 0.0),
            ],
            vec![
                make_edge("a", "b", linear(0.5, 0.0)),
                make_edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        dag.validate().unwrap();

        let result = dag.propagate(&pin_a_to_20(), 0);
        // a is pinned (+10 over baseline); b = 5 + 0.5 * 10; c = 0 + (10 - 5).
        assert_close(result["a"], 20.0, 0.001);
        assert_close(result["b"], 10.0, 0.001);
        assert_close(result["c"], 5.0, 0.001);
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let mut lagged = make_edge("a", "b", linear(1.0, 0.0));
        lagged.lag_months = 2;
        let mut dag = dag_of(vec![make_node("a", 10.0), make_node("b", 5.0)], vec![lagged]);
        dag.validate().unwrap();

        let interventions = pin_a_to_20();
        // Before the lag elapses b stays at baseline; afterwards it moves by +10.
        assert_close(dag.propagate(&interventions, 1)["b"], 5.0, 0.001);
        assert_close(dag.propagate(&interventions, 2)["b"], 15.0, 0.001);
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let mut bounded = make_node("b", 5.0);
        bounded.bounds = Some((0.0, 8.0));
        let mut dag = dag_of(
            vec![make_node("a", 10.0), bounded],
            vec![make_edge("a", "b", linear(1.0, 0.0))],
        );
        dag.validate().unwrap();

        // Unclamped b would be 15; the upper bound caps it at 8.
        let result = dag.propagate(&pin_a_to_20(), 0);
        assert_close(result["b"], 8.0, 0.001);
    }

    // ---- serialization ------------------------------------------------------

    #[test]
    fn test_transfer_function_serde() {
        let original = linear(0.5, 1.0);
        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: TransferFunction = serde_json::from_str(&json).unwrap();
        assert_close(round_tripped.compute(2.0), 2.0, f64::EPSILON);
    }

    // ---- transfer function edge cases ---------------------------------------

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        let constant = linear(0.0, 5.0);
        for input in [0.0, 100.0, -100.0] {
            assert_close(constant.compute(input), 5.0, f64::EPSILON);
        }
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        let falling = linear(-2.0, 10.0);
        assert_close(falling.compute(3.0), 4.0, f64::EPSILON);
        assert_close(falling.compute(5.0), 0.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let growth = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        assert_close(growth.compute(0.0), 5.0, 0.001);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        // 100 * 0.5^2 = 25
        let shrink = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        assert_close(shrink.compute(2.0), 25.0, 0.001);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let sigmoid = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        assert!(sigmoid.compute(-10.0) < 0.01);
        assert_close(sigmoid.compute(20.0), 10.0, 0.01);
        assert_close(sigmoid.compute(5.0), 5.0, 0.01);
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        let with_steepness = |steepness: f64| TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness,
        };
        let steep = with_steepness(10.0);
        let gentle = with_steepness(0.5);
        // Steepness does not move the midpoint, but sharpens the rise past it.
        assert_close(steep.compute(0.0), 0.5, 0.01);
        assert_close(gentle.compute(0.0), 0.5, 0.01);
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let falling = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        assert_close(falling.compute(0.0), 0.5, 0.001);
        assert!(falling.compute(10.0) < 0.01);
        assert_close(falling.compute(-10.0), 1.0, 0.01);
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // With identical parameters the two curves always sum to the capacity.
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let step = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // The comparison is strict: exactly at the threshold nothing fires.
        assert_close(step.compute(5.0), 0.0, f64::EPSILON);
        assert_close(step.compute(5.001), 10.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let step = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        assert_close(step.compute(-1.0), 0.0, f64::EPSILON);
        assert_close(step.compute(1.0), -5.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let gate = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        assert_close(gate.compute(1.0), 0.0, f64::EPSILON);
        // 10 * 0.5 / 2 = 2.5, still below the cap.
        assert_close(gate.compute(2.5), 2.5, 0.001);
        // 10 * 98 / 2 = 490, capped at 8.
        assert_close(gate.compute(100.0), 8.0, 0.001);
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let gate = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        // 5 * 99 / 1 = 495
        assert_close(gate.compute(100.0), 495.0, 0.001);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let decay = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        assert!(decay.compute(10.0) < 0.01);
        assert!(decay.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let flat = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        assert_close(flat.compute(0.0), 50.0, f64::EPSILON);
        assert_close(flat.compute(100.0), 50.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let curve = TransferFunction::Piecewise {
            points: vec![(5.0, 42.0)],
        };
        assert_close(curve.compute(0.0), 42.0, f64::EPSILON);
        assert_close(curve.compute(100.0), 42.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let curve = TransferFunction::Piecewise { points: Vec::new() };
        assert_close(curve.compute(5.0), 0.0, f64::EPSILON);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let knots = vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)];
        let curve = TransferFunction::Piecewise {
            points: knots.clone(),
        };
        // Evaluating exactly at a knot returns that knot's y.
        for (x, y) in knots {
            assert_close(curve.compute(x), y, 0.001);
        }
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        let curve = TransferFunction::Piecewise {
            points: vec![(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)],
        };
        assert_close(curve.compute(0.5), 5.0, 0.001);
        assert_close(curve.compute(1.5), 15.0, 0.001);
    }
}