/// How a [`SheafConfig`]'s restriction value responds to flow energy.
///
/// See [`SheafConfig::eval_restriction`] for the formula each mode uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvolutionMode {
    /// The restriction is the constant `r0`, independent of flow energy.
    Static,
    /// The restriction changes affinely: `r0 + alpha * flow_energy`.
    Linear,
    /// The restriction is `r0` scaled by a [`NonlinearFn`] of the flow energy.
    Nonlinear,
}

/// Scalar nonlinearities used by [`EvolutionMode::Nonlinear`].
///
/// Each variant is evaluated through [`NonlinearFn::apply`] with a
/// sharpness parameter `k` that scales the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NonlinearFn {
    /// Logistic function `1 / (1 + exp(-k·x))`; equals `0.5` at `x = 0`.
    Sigmoid,
    /// Hyperbolic tangent `tanh(k·x)`; equals `0` at `x = 0`.
    Tanh,
    /// Gaussian-shaped decay `exp(-k·x²)`; equals `1` at `x = 0`.
    ExpDecay,
}

impl NonlinearFn {
    /// Evaluates this nonlinearity at `x` with sharpness `k`.
    ///
    /// For positive `k`, a larger value gives a steeper transition (Sigmoid,
    /// Tanh) or a faster decay away from zero (ExpDecay).
    #[must_use]
    pub fn apply(&self, x: f64, k: f64) -> f64 {
        match self {
            // For very negative k·x, exp overflows to +inf and the result
            // is exactly 0.0, so no NaN is produced.
            NonlinearFn::Sigmoid => 1.0 / (1.0 + (-k * x).exp()),
            NonlinearFn::Tanh => (k * x).tanh(),
            NonlinearFn::ExpDecay => (-k * x * x).exp(),
        }
    }
}

32#[derive(Debug, Clone, Copy, PartialEq)]
34pub struct SheafConfig {
35 pub model: EvolutionMode,
36 pub r0: f64,
38 pub alpha: f64,
40 pub nonlin: NonlinearFn,
42 pub nonlin_k: f64,
44}
45
46impl Default for SheafConfig {
47 fn default() -> Self {
48 SheafConfig {
49 model: EvolutionMode::Static,
50 r0: 1.0,
51 alpha: 0.0,
52 nonlin: NonlinearFn::Sigmoid,
53 nonlin_k: 1.0,
54 }
55 }
56}
57
58impl SheafConfig {
59 pub fn eval_restriction(&self, flow_energy: f64) -> f64 {
61 match self.model {
62 EvolutionMode::Static => self.r0,
63 EvolutionMode::Linear => self.r0 + self.alpha * flow_energy,
64 EvolutionMode::Nonlinear => self.r0 * self.nonlin.apply(flow_energy, self.nonlin_k),
65 }
66 }
67
68 pub fn static_sheaf(r0: f64) -> Self {
70 SheafConfig {
71 model: EvolutionMode::Static,
72 r0,
73 alpha: 0.0,
74 nonlin: NonlinearFn::Sigmoid,
75 nonlin_k: 1.0,
76 }
77 }
78
79 pub fn linear(r0: f64, alpha: f64) -> Self {
81 SheafConfig {
82 model: EvolutionMode::Linear,
83 r0,
84 alpha,
85 nonlin: NonlinearFn::Sigmoid,
86 nonlin_k: 1.0,
87 }
88 }
89
90 pub fn nonlinear(r0: f64, nonlin: NonlinearFn, k: f64) -> Self {
92 SheafConfig {
93 model: EvolutionMode::Nonlinear,
94 r0,
95 alpha: 0.0,
96 nonlin,
97 nonlin_k: k,
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons below.
    const EPS: f64 = 1e-6;

    #[test]
    fn test_sigmoid_zero() {
        let v = NonlinearFn::Sigmoid.apply(0.0, 1.0);
        assert!((v - 0.5).abs() < EPS, "sigmoid(0) = 0.5");
    }

    #[test]
    fn test_tanh_zero() {
        let v = NonlinearFn::Tanh.apply(0.0, 1.0);
        assert!(v.abs() < EPS, "tanh(0) = 0");
    }

    #[test]
    fn test_expdecay_zero() {
        let v = NonlinearFn::ExpDecay.apply(0.0, 1.0);
        assert!((v - 1.0).abs() < EPS, "exp(-0) = 1");
    }

    #[test]
    fn test_sigmoid_large_positive() {
        let v = NonlinearFn::Sigmoid.apply(100.0, 1.0);
        assert!((v - 1.0).abs() < EPS, "sigmoid(∞) → 1");
    }

    #[test]
    fn test_sigmoid_large_negative() {
        let v = NonlinearFn::Sigmoid.apply(-100.0, 1.0);
        assert!(v.abs() < EPS, "sigmoid(-∞) → 0");
    }

    #[test]
    fn test_eval_static() {
        // Avoid literals such as 3.14 that approximate PI: clippy's
        // deny-by-default `approx_constant` lint rejects them.
        let cfg = SheafConfig::static_sheaf(2.5);
        let r = cfg.eval_restriction(5.0);
        assert!((r - 2.5).abs() < EPS, "static R = R₀");
    }

    #[test]
    fn test_eval_linear() {
        let cfg = SheafConfig::linear(1.0, 0.5);
        let r = cfg.eval_restriction(2.0);
        assert!((r - 2.0).abs() < EPS, "linear: 1.0 + 0.5*2.0 = 2.0");
    }

    #[test]
    fn test_eval_nonlinear_tanh() {
        let cfg = SheafConfig::nonlinear(2.0, NonlinearFn::Tanh, 1.0);
        let r = cfg.eval_restriction(1.0);
        let expected = 2.0 * (1.0_f64).tanh();
        assert!((r - expected).abs() < EPS, "nonlinear: 2.0*tanh(1.0)");
    }

    #[test]
    fn test_eval_nonlinear_sigmoid() {
        let cfg = SheafConfig::nonlinear(2.0, NonlinearFn::Sigmoid, 1.0);
        let r = cfg.eval_restriction(0.0);
        assert!((r - 1.0).abs() < EPS, "nonlinear sigmoid(0) = 1.0");
    }

    #[test]
    fn test_default() {
        let cfg: SheafConfig = Default::default();
        assert_eq!(cfg.model, EvolutionMode::Static);
        assert!((cfg.r0 - 1.0).abs() < 1e-12);
        assert!(cfg.alpha.abs() < 1e-12);
        assert_eq!(cfg.nonlin, NonlinearFn::Sigmoid);
        assert!((cfg.nonlin_k - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_constructors_agree_with_default_on_unused_fields() {
        assert_eq!(SheafConfig::static_sheaf(1.0), SheafConfig::default());

        let lin = SheafConfig::linear(2.0, 0.25);
        assert_eq!(lin.model, EvolutionMode::Linear);
        assert_eq!(lin.nonlin, NonlinearFn::Sigmoid);
        assert!((lin.nonlin_k - 1.0).abs() < 1e-12);

        let nl = SheafConfig::nonlinear(3.0, NonlinearFn::ExpDecay, 0.5);
        assert_eq!(nl.model, EvolutionMode::Nonlinear);
        assert_eq!(nl.nonlin, NonlinearFn::ExpDecay);
        assert!(nl.alpha.abs() < 1e-12);
    }
}