1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use thiserror::Error;
4
/// A causal graph: nodes that carry baseline values, plus directed edges
/// describing how a deviation in one node affects another.
///
/// `topological_order` is not part of the serialized form, so a freshly
/// constructed or deserialized graph must go through `validate` before
/// `propagate` will push any effects along the edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalDAG {
    /// All nodes of the graph. `validate` rejects duplicate IDs.
    pub nodes: Vec<CausalNode>,
    /// Directed edges; both endpoints must name a node in `nodes` for
    /// `validate` to succeed.
    pub edges: Vec<CausalEdge>,
    /// Node IDs in topological order. Skipped by serde; filled in by
    /// `validate` and iterated by `propagate`.
    #[serde(skip)]
    pub topological_order: Vec<String>,
}
15
/// A single variable in the causal graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalNode {
    /// Identifier referenced by `CausalEdge::from` / `CausalEdge::to` and by
    /// the keys of the intervention map passed to `CausalDAG::propagate`.
    pub id: String,
    /// Display label. Not read by any logic visible in this file.
    pub label: String,
    /// Category tag for the node. Not read by any logic visible in this file.
    pub category: NodeCategory,
    /// Value the node takes when nothing upstream deviates; effects computed
    /// by `CausalDAG::propagate` are added on top of it.
    pub baseline_value: f64,
    /// Optional `(min, max)` pair. `CausalDAG::propagate` clamps a node's
    /// recomputed value into this range; baselines and intervention values
    /// are stored as given.
    pub bounds: Option<(f64, f64)>,
    /// Defaults to `true` when absent from the input.
    /// NOTE(review): nothing in this file reads the flag; presumably callers
    /// check it before intervening (see `CausalDAGError::NonInterventionable`)
    /// — confirm.
    #[serde(default = "default_true")]
    pub interventionable: bool,
    /// Defaults to an empty list when absent from the input.
    /// NOTE(review): presumably names of configuration entries driven by this
    /// node; not read by any logic visible in this file — confirm.
    #[serde(default)]
    pub config_bindings: Vec<String>,
}
33
/// Serde default for `CausalNode::interventionable`: a node whose input omits
/// the flag is treated as interventionable.
fn default_true() -> bool {
    const INTERVENTIONABLE_WHEN_UNSPECIFIED: bool = true;
    INTERVENTIONABLE_WHEN_UNSPECIFIED
}
37
/// Category tag attached to a `CausalNode`.
///
/// Serialized in `snake_case` (e.g. `Macro` is written as `"macro"`).
/// No logic in this file branches on the category.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum NodeCategory {
    Macro,
    Operational,
    Control,
    Financial,
    Behavioral,
    Regulatory,
    Outcome,
    Audit,
}
50
/// A directed causal link from one node to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalEdge {
    /// ID of the source (cause) node.
    pub from: String,
    /// ID of the target (effect) node.
    pub to: String,
    /// Maps the source node's deviation from its baseline to an effect on the
    /// target node.
    pub transfer: TransferFunction,
    /// Delay before the edge becomes active: `CausalDAG::propagate` ignores
    /// the edge while `month < lag_months`. Defaults to 0.
    #[serde(default)]
    pub lag_months: u32,
    /// Multiplier applied to the transfer function's output. Defaults to 1.0.
    #[serde(default = "default_strength")]
    pub strength: f64,
    /// Optional free-text note. Not read by any logic visible in this file.
    pub mechanism: Option<String>,
}
65
/// Serde default for `CausalEdge::strength`: the transfer function's output is
/// passed through unscaled.
fn default_strength() -> f64 {
    const UNSCALED: f64 = 1.0;
    UNSCALED
}
69
/// Shape of the relationship carried by a `CausalEdge`.
///
/// Serialized as an internally tagged object: the variant name goes into a
/// `"type"` field in `snake_case`, next to the variant's own fields. The
/// formulas below are the ones implemented by `TransferFunction::compute`,
/// with `x` standing for its `input`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TransferFunction {
    /// `x * coefficient + intercept`; `intercept` defaults to 0.
    Linear {
        coefficient: f64,
        #[serde(default)]
        intercept: f64,
    },
    /// `base * (1 + rate)^x`.
    Exponential { base: f64, rate: f64 },
    /// `capacity / (1 + exp(-steepness * (x - midpoint)))`.
    Logistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// `capacity / (1 + exp(steepness * (x - midpoint)))`, i.e. the logistic
    /// curve mirrored around `midpoint`.
    InverseLogistic {
        capacity: f64,
        midpoint: f64,
        steepness: f64,
    },
    /// `magnitude` when `x` is strictly greater than `threshold`, else 0.
    Step { threshold: f64, magnitude: f64 },
    /// 0 up to and including `threshold`; above it,
    /// `magnitude * (x - threshold) / threshold`, capped at `saturation`.
    ///
    /// `saturation` defaults to positive infinity (no cap).
    /// NOTE(review): JSON has no literal for infinity — confirm that a value
    /// with the defaulted `saturation` round-trips through the serializer in
    /// use.
    Threshold {
        threshold: f64,
        magnitude: f64,
        #[serde(default = "default_saturation")]
        saturation: f64,
    },
    /// `initial * exp(-decay_rate * x)`.
    Decay { initial: f64, decay_rate: f64 },
    /// Linear interpolation through `(x, y)` points; the points need not be
    /// given in order, and inputs outside the covered range take the `y` of
    /// the nearest end point.
    Piecewise { points: Vec<(f64, f64)> },
}
107
/// Serde default for the `saturation` field of `TransferFunction::Threshold`:
/// positive infinity, which leaves the ramp uncapped.
fn default_saturation() -> f64 {
    const NO_CAP: f64 = f64::INFINITY;
    NO_CAP
}
111
112impl TransferFunction {
113 pub fn compute(&self, input: f64) -> f64 {
115 match self {
116 TransferFunction::Linear {
117 coefficient,
118 intercept,
119 } => input * coefficient + intercept,
120
121 TransferFunction::Exponential { base, rate } => base * (1.0 + rate).powf(input),
122
123 TransferFunction::Logistic {
124 capacity,
125 midpoint,
126 steepness,
127 } => capacity / (1.0 + (-steepness * (input - midpoint)).exp()),
128
129 TransferFunction::InverseLogistic {
130 capacity,
131 midpoint,
132 steepness,
133 } => capacity / (1.0 + (steepness * (input - midpoint)).exp()),
134
135 TransferFunction::Step {
136 threshold,
137 magnitude,
138 } => {
139 if input > *threshold {
140 *magnitude
141 } else {
142 0.0
143 }
144 }
145
146 TransferFunction::Threshold {
147 threshold,
148 magnitude,
149 saturation,
150 } => {
151 if input > *threshold {
152 (magnitude * (input - threshold) / threshold).min(*saturation)
153 } else {
154 0.0
155 }
156 }
157
158 TransferFunction::Decay {
159 initial,
160 decay_rate,
161 } => initial * (-decay_rate * input).exp(),
162
163 TransferFunction::Piecewise { points } => {
164 if points.is_empty() {
165 return 0.0;
166 }
167 if points.len() == 1 {
168 return points[0].1;
169 }
170
171 let mut sorted = points.clone();
173 sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
174
175 if input <= sorted[0].0 {
177 return sorted[0].1;
178 }
179 if input >= sorted[sorted.len() - 1].0 {
180 return sorted[sorted.len() - 1].1;
181 }
182
183 for window in sorted.windows(2) {
185 let (x0, y0) = window[0];
186 let (x1, y1) = window[1];
187 if input >= x0 && input <= x1 {
188 let t = (input - x0) / (x1 - x0);
189 return y0 + t * (y1 - y0);
190 }
191 }
192
193 sorted[sorted.len() - 1].1
194 }
195 }
196 }
197}
198
/// Errors reported when checking a `CausalDAG` or the interventions applied
/// to it.
#[derive(Debug, Error)]
pub enum CausalDAGError {
    /// The edges contain a directed cycle, so no topological order exists.
    #[error("cycle detected in causal DAG")]
    CycleDetected,
    /// An edge endpoint names a node ID that is not in `nodes`; the payload
    /// is the offending ID.
    #[error("unknown node referenced in edge: {0}")]
    UnknownNode(String),
    /// Two nodes share the same ID; the payload is that ID.
    #[error("duplicate node ID: {0}")]
    DuplicateNode(String),
    /// The named node may not be intervened on.
    /// NOTE(review): not constructed anywhere in this file; presumably raised
    /// by callers that check `CausalNode::interventionable` — confirm.
    #[error("node '{0}' is not interventionable")]
    NonInterventionable(String),
}
211
212impl CausalDAG {
213 pub fn validate(&mut self) -> Result<(), CausalDAGError> {
215 let node_ids: HashSet<&str> = self.nodes.iter().map(|n| n.id.as_str()).collect();
216
217 let mut seen = HashSet::new();
219 for node in &self.nodes {
220 if !seen.insert(&node.id) {
221 return Err(CausalDAGError::DuplicateNode(node.id.clone()));
222 }
223 }
224
225 for edge in &self.edges {
227 if !node_ids.contains(edge.from.as_str()) {
228 return Err(CausalDAGError::UnknownNode(edge.from.clone()));
229 }
230 if !node_ids.contains(edge.to.as_str()) {
231 return Err(CausalDAGError::UnknownNode(edge.to.clone()));
232 }
233 }
234
235 let mut in_degree: HashMap<&str, usize> = HashMap::new();
237 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
238
239 for node in &self.nodes {
240 in_degree.insert(&node.id, 0);
241 adjacency.insert(&node.id, Vec::new());
242 }
243
244 for edge in &self.edges {
245 *in_degree.entry(&edge.to).or_insert(0) += 1;
246 adjacency.entry(&edge.from).or_default().push(&edge.to);
247 }
248
249 let mut queue: VecDeque<&str> = VecDeque::new();
250 for (node, °ree) in &in_degree {
251 if degree == 0 {
252 queue.push_back(node);
253 }
254 }
255
256 let mut order = Vec::new();
257 while let Some(node) = queue.pop_front() {
258 order.push(node.to_string());
259 if let Some(neighbors) = adjacency.get(node) {
260 for &neighbor in neighbors {
261 if let Some(degree) = in_degree.get_mut(neighbor) {
262 *degree -= 1;
263 if *degree == 0 {
264 queue.push_back(neighbor);
265 }
266 }
267 }
268 }
269 }
270
271 if order.len() != self.nodes.len() {
272 return Err(CausalDAGError::CycleDetected);
273 }
274
275 self.topological_order = order;
276 Ok(())
277 }
278
279 pub fn find_node(&self, id: &str) -> Option<&CausalNode> {
281 self.nodes.iter().find(|n| n.id == id)
282 }
283
284 pub fn propagate(
287 &self,
288 interventions: &HashMap<String, f64>,
289 month: u32,
290 ) -> HashMap<String, f64> {
291 let mut values: HashMap<String, f64> = HashMap::new();
292
293 for node in &self.nodes {
295 values.insert(node.id.clone(), node.baseline_value);
296 }
297
298 for (node_id, value) in interventions {
300 values.insert(node_id.clone(), *value);
301 }
302
303 let mut incoming: HashMap<&str, Vec<&CausalEdge>> = HashMap::new();
305 for edge in &self.edges {
306 incoming.entry(&edge.to).or_default().push(edge);
307 }
308
309 for node_id in &self.topological_order {
311 if interventions.contains_key(node_id) {
313 continue;
314 }
315
316 if let Some(edges) = incoming.get(node_id.as_str()) {
317 let mut total_effect = 0.0;
318 let mut has_effect = false;
319
320 for edge in edges {
321 if month < edge.lag_months {
323 continue;
324 }
325
326 let from_value = values.get(&edge.from).copied().unwrap_or(0.0);
327 let baseline = self
328 .find_node(&edge.from)
329 .map(|n| n.baseline_value)
330 .unwrap_or(0.0);
331
332 let delta = from_value - baseline;
334 if delta.abs() < f64::EPSILON {
335 continue;
336 }
337
338 let effect = edge.transfer.compute(delta) * edge.strength;
340 total_effect += effect;
341 has_effect = true;
342 }
343
344 if has_effect {
345 let baseline = self
346 .find_node(node_id)
347 .map(|n| n.baseline_value)
348 .unwrap_or(0.0);
349 let mut new_value = baseline + total_effect;
350
351 if let Some(node) = self.find_node(node_id) {
353 if let Some((min, max)) = node.bounds {
354 new_value = new_value.clamp(min, max);
355 }
356 }
357
358 values.insert(node_id.clone(), new_value);
359 }
360 }
361 }
362
363 values
364 }
365}
366
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Tolerance used for results that are expected to be exact.
    const EXACT: f64 = f64::EPSILON;
    /// Tolerance used for results that go through `exp`/`powf` or interpolation.
    const LOOSE: f64 = 0.001;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Builds an unbounded, interventionable operational node labelled by its ID.
    fn node(id: &str, baseline: f64) -> CausalNode {
        CausalNode {
            id: id.to_string(),
            label: id.to_string(),
            category: NodeCategory::Operational,
            baseline_value: baseline,
            bounds: None,
            interventionable: true,
            config_bindings: vec![],
        }
    }

    /// Builds an immediate (no lag), full-strength edge.
    fn edge(from: &str, to: &str, transfer: TransferFunction) -> CausalEdge {
        CausalEdge {
            from: from.to_string(),
            to: to.to_string(),
            transfer,
            lag_months: 0,
            strength: 1.0,
            mechanism: None,
        }
    }

    /// Shorthand for a linear transfer function.
    fn linear(coefficient: f64, intercept: f64) -> TransferFunction {
        TransferFunction::Linear {
            coefficient,
            intercept,
        }
    }

    /// Assembles a graph that has not been validated yet.
    fn dag(nodes: Vec<CausalNode>, edges: Vec<CausalEdge>) -> CausalDAG {
        CausalDAG {
            nodes,
            edges,
            topological_order: vec![],
        }
    }

    /// Builds an intervention map that forces a single node to `value`.
    fn intervene(id: &str, value: f64) -> HashMap<String, f64> {
        let mut forced = HashMap::new();
        forced.insert(id.to_string(), value);
        forced
    }

    #[test]
    fn test_transfer_function_linear() {
        // 2.0 * 0.5 + 1.0
        assert_close(linear(0.5, 1.0).compute(2.0), 2.0, EXACT);
    }

    #[test]
    fn test_transfer_function_logistic() {
        let curve = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // At the midpoint the curve sits at half its capacity.
        assert_close(curve.compute(0.0), 0.5, LOOSE);
    }

    #[test]
    fn test_transfer_function_exponential() {
        let growth = TransferFunction::Exponential {
            base: 1.0,
            rate: 1.0,
        };
        // 1.0 * 2^3
        assert_close(growth.compute(3.0), 8.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_step() {
        let step = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        assert_close(step.compute(3.0), 0.0, EXACT);
        assert_close(step.compute(6.0), 10.0, EXACT);
    }

    #[test]
    fn test_transfer_function_threshold() {
        let ramp = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: f64::INFINITY,
        };
        assert_close(ramp.compute(1.0), 0.0, EXACT);
        // 10 * (3 - 2) / 2
        assert_close(ramp.compute(3.0), 5.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_decay() {
        let decay = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 0.5,
        };
        assert_close(decay.compute(0.0), 100.0, LOOSE);
        // 100 * exp(-0.5)
        assert_close(decay.compute(1.0), 60.653, 0.1);
    }

    #[test]
    fn test_transfer_function_piecewise() {
        let curve = TransferFunction::Piecewise {
            points: vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0)],
        };
        assert_close(curve.compute(0.5), 5.0, LOOSE);
        assert_close(curve.compute(1.5), 12.5, LOOSE);
        // Outside the covered range the end values are held.
        assert_close(curve.compute(-1.0), 0.0, LOOSE);
        assert_close(curve.compute(3.0), 15.0, LOOSE);
    }

    #[test]
    fn test_dag_validate_acyclic() {
        let mut graph = dag(
            vec![node("a", 1.0), node("b", 2.0), node("c", 3.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        assert!(graph.validate().is_ok());
        assert_eq!(graph.topological_order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_dag_validate_cycle_detected() {
        let mut graph = dag(
            vec![node("a", 1.0), node("b", 2.0)],
            vec![
                edge("a", "b", linear(1.0, 0.0)),
                edge("b", "a", linear(1.0, 0.0)),
            ],
        );
        assert!(matches!(
            graph.validate(),
            Err(CausalDAGError::CycleDetected)
        ));
    }

    #[test]
    fn test_dag_validate_unknown_node() {
        let mut graph = dag(
            vec![node("a", 1.0)],
            vec![edge("a", "nonexistent", linear(1.0, 0.0))],
        );
        assert!(matches!(
            graph.validate(),
            Err(CausalDAGError::UnknownNode(_))
        ));
    }

    #[test]
    fn test_dag_validate_duplicate_node() {
        let mut graph = dag(vec![node("a", 1.0), node("a", 2.0)], vec![]);
        assert!(matches!(
            graph.validate(),
            Err(CausalDAGError::DuplicateNode(_))
        ));
    }

    #[test]
    fn test_dag_propagate_chain() {
        let mut graph = dag(
            vec![node("a", 10.0), node("b", 5.0), node("c", 0.0)],
            vec![
                edge("a", "b", linear(0.5, 0.0)),
                edge("b", "c", linear(1.0, 0.0)),
            ],
        );
        graph.validate().unwrap();

        let outcome = graph.propagate(&intervene("a", 20.0), 0);
        assert_close(outcome["a"], 20.0, LOOSE);
        // b: 5 + 0.5 * (20 - 10)
        assert_close(outcome["b"], 10.0, LOOSE);
        // c: 0 + 1.0 * (10 - 5)
        assert_close(outcome["c"], 5.0, LOOSE);
    }

    #[test]
    fn test_dag_propagate_with_lag() {
        let mut delayed = edge("a", "b", linear(1.0, 0.0));
        delayed.lag_months = 2;
        let mut graph = dag(vec![node("a", 10.0), node("b", 5.0)], vec![delayed]);
        graph.validate().unwrap();

        let forced = intervene("a", 20.0);

        // Before the lag has elapsed the target stays at its baseline.
        let early = graph.propagate(&forced, 1);
        assert_close(early["b"], 5.0, LOOSE);

        // From the lag month onwards the effect arrives: 5 + (20 - 10).
        let on_time = graph.propagate(&forced, 2);
        assert_close(on_time["b"], 15.0, LOOSE);
    }

    #[test]
    fn test_dag_propagate_node_bounds_clamped() {
        let mut capped = node("b", 5.0);
        capped.bounds = Some((0.0, 8.0));
        let mut graph = dag(
            vec![node("a", 10.0), capped],
            vec![edge("a", "b", linear(1.0, 0.0))],
        );
        graph.validate().unwrap();

        // Unclamped the value would be 15; the upper bound caps it at 8.
        let outcome = graph.propagate(&intervene("a", 20.0), 0);
        assert_close(outcome["b"], 8.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_serde() {
        let original = linear(0.5, 1.0);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TransferFunction = serde_json::from_str(&encoded).unwrap();
        assert_close(decoded.compute(2.0), 2.0, EXACT);
    }

    #[test]
    fn test_transfer_function_linear_zero_coefficient() {
        let flat = linear(0.0, 5.0);
        for input in [0.0, 100.0, -100.0] {
            assert_close(flat.compute(input), 5.0, EXACT);
        }
    }

    #[test]
    fn test_transfer_function_linear_negative_coefficient() {
        let falling = linear(-2.0, 10.0);
        assert_close(falling.compute(3.0), 4.0, EXACT);
        assert_close(falling.compute(5.0), 0.0, EXACT);
    }

    #[test]
    fn test_transfer_function_exponential_zero_input() {
        let growth = TransferFunction::Exponential {
            base: 5.0,
            rate: 0.5,
        };
        // Anything to the power of zero is one, leaving just the base.
        assert_close(growth.compute(0.0), 5.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_exponential_negative_rate() {
        let shrink = TransferFunction::Exponential {
            base: 100.0,
            rate: -0.5,
        };
        // 100 * 0.5^2
        assert_close(shrink.compute(2.0), 25.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_logistic_far_from_midpoint() {
        let curve = TransferFunction::Logistic {
            capacity: 10.0,
            midpoint: 5.0,
            steepness: 2.0,
        };
        assert!(curve.compute(-10.0) < 0.01);
        assert_close(curve.compute(20.0), 10.0, 0.01);
        assert_close(curve.compute(5.0), 5.0, 0.01);
    }

    #[test]
    fn test_transfer_function_logistic_steepness_effect() {
        let logistic_with = |steepness: f64| TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness,
        };
        let steep = logistic_with(10.0);
        let gentle = logistic_with(0.5);

        assert_close(steep.compute(0.0), 0.5, 0.01);
        assert_close(gentle.compute(0.0), 0.5, 0.01);
        // Past the midpoint the steeper curve is closer to capacity.
        assert!(steep.compute(1.0) > gentle.compute(1.0));
    }

    #[test]
    fn test_transfer_function_inverse_logistic() {
        let curve = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        assert_close(curve.compute(0.0), 0.5, LOOSE);
        assert!(curve.compute(10.0) < 0.01);
        assert_close(curve.compute(-10.0), 1.0, 0.01);
    }

    #[test]
    fn test_transfer_function_inverse_logistic_symmetry() {
        let logistic = TransferFunction::Logistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        let inverse = TransferFunction::InverseLogistic {
            capacity: 1.0,
            midpoint: 0.0,
            steepness: 1.0,
        };
        // With matching parameters the two curves always add up to the capacity.
        for x in [-5.0, -1.0, 0.0, 1.0, 5.0] {
            let sum = logistic.compute(x) + inverse.compute(x);
            assert!((sum - 1.0).abs() < 0.001, "Sum at x={} was {}", x, sum);
        }
    }

    #[test]
    fn test_transfer_function_step_at_threshold() {
        let step = TransferFunction::Step {
            threshold: 5.0,
            magnitude: 10.0,
        };
        // The comparison is strict: exactly at the threshold nothing fires.
        assert_close(step.compute(5.0), 0.0, EXACT);
        assert_close(step.compute(5.001), 10.0, EXACT);
    }

    #[test]
    fn test_transfer_function_step_negative_magnitude() {
        let step = TransferFunction::Step {
            threshold: 0.0,
            magnitude: -5.0,
        };
        assert_close(step.compute(-1.0), 0.0, EXACT);
        assert_close(step.compute(1.0), -5.0, EXACT);
    }

    #[test]
    fn test_transfer_function_threshold_with_saturation() {
        let ramp = TransferFunction::Threshold {
            threshold: 2.0,
            magnitude: 10.0,
            saturation: 8.0,
        };
        assert_close(ramp.compute(1.0), 0.0, EXACT);
        // 10 * (2.5 - 2) / 2
        assert_close(ramp.compute(2.5), 2.5, LOOSE);
        // 10 * 98 / 2 = 490, capped at the saturation level.
        assert_close(ramp.compute(100.0), 8.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_threshold_infinite_saturation() {
        let ramp = TransferFunction::Threshold {
            threshold: 1.0,
            magnitude: 5.0,
            saturation: f64::INFINITY,
        };
        // 5 * (100 - 1) / 1
        assert_close(ramp.compute(100.0), 495.0, LOOSE);
    }

    #[test]
    fn test_transfer_function_decay_large_input() {
        let decay = TransferFunction::Decay {
            initial: 100.0,
            decay_rate: 1.0,
        };
        assert!(decay.compute(10.0) < 0.01);
        assert!(decay.compute(20.0) < 0.0001);
    }

    #[test]
    fn test_transfer_function_decay_zero_rate() {
        let constant = TransferFunction::Decay {
            initial: 50.0,
            decay_rate: 0.0,
        };
        assert_close(constant.compute(0.0), 50.0, EXACT);
        assert_close(constant.compute(100.0), 50.0, EXACT);
    }

    #[test]
    fn test_transfer_function_piecewise_single_point() {
        let constant = TransferFunction::Piecewise {
            points: vec![(5.0, 42.0)],
        };
        assert_close(constant.compute(0.0), 42.0, EXACT);
        assert_close(constant.compute(100.0), 42.0, EXACT);
    }

    #[test]
    fn test_transfer_function_piecewise_empty() {
        let empty = TransferFunction::Piecewise { points: vec![] };
        assert_close(empty.compute(5.0), 0.0, EXACT);
    }

    #[test]
    fn test_transfer_function_piecewise_exact_points() {
        let knots = vec![(0.0, 0.0), (1.0, 10.0), (2.0, 15.0), (3.0, 30.0)];
        let curve = TransferFunction::Piecewise {
            points: knots.clone(),
        };
        // Evaluating exactly on a knot returns that knot's ordinate.
        for (x, y) in knots {
            assert_close(curve.compute(x), y, LOOSE);
        }
    }

    #[test]
    fn test_transfer_function_piecewise_unsorted_points() {
        let curve = TransferFunction::Piecewise {
            points: vec![(2.0, 20.0), (0.0, 0.0), (1.0, 10.0)],
        };
        assert_close(curve.compute(0.5), 5.0, LOOSE);
        assert_close(curve.compute(1.5), 15.0, LOOSE);
    }
}