1use pflow_core::State;
4
5use crate::methods;
6use crate::ode::{copy_state, Options, Problem, Solution};
7
/// Settings that decide when [`solve_until_equilibrium`] treats a trajectory
/// as having settled.
#[derive(Debug, Clone, PartialEq)]
pub struct EquilibriumOptions {
    /// A check counts as "small" when the largest absolute derivative
    /// component is strictly below this value.
    pub tolerance: f64,
    /// Number of consecutive small checks required before stopping.
    pub consecutive_steps: usize,
    /// Time that must elapse after the start of the time span before any
    /// check is made.
    pub min_time: f64,
    /// Number of accepted steps between checks; `0` means check after every
    /// accepted step.
    pub check_interval: usize,
}

impl EquilibriumOptions {
    /// Balanced preset: tolerance `1e-6`, 5 consecutive small checks, checks
    /// start at `t0 + 0.1` and run every 10 accepted steps.
    ///
    /// This is also what [`Default::default`] returns.
    pub fn default_opts() -> Self {
        Self {
            tolerance: 1e-6,
            consecutive_steps: 5,
            min_time: 0.1,
            check_interval: 10,
        }
    }

    /// Loose, quick-to-trigger preset: tolerance `1e-4`, 3 consecutive small
    /// checks, checks start at `t0 + 0.01` and run every 5 accepted steps.
    pub fn fast() -> Self {
        Self {
            tolerance: 1e-4,
            consecutive_steps: 3,
            min_time: 0.01,
            check_interval: 5,
        }
    }

    /// Conservative preset: tolerance `1e-9`, 10 consecutive small checks,
    /// checks start at `t0 + 1.0` and run after every accepted step.
    pub fn strict() -> Self {
        Self {
            tolerance: 1e-9,
            consecutive_steps: 10,
            min_time: 1.0,
            check_interval: 1,
        }
    }
}

impl Default for EquilibriumOptions {
    /// Same as [`EquilibriumOptions::default_opts`].
    fn default() -> Self {
        Self::default_opts()
    }
}
45
46#[derive(Debug, Clone)]
48pub struct EquilibriumResult {
49 pub reached: bool,
50 pub time: f64,
51 pub state: State,
52 pub max_change: f64,
53 pub steps: usize,
54 pub reason: String,
55}
56
57pub fn solve_until_equilibrium(
59 prob: &Problem,
60 solver: &methods::Solver,
61 opts: &Options,
62 eq_opts: &EquilibriumOptions,
63) -> (Solution, EquilibriumResult) {
64 let dt = opts.dt;
65 let dtmin = opts.dtmin;
66 let dtmax = opts.dtmax;
67 let abstol = opts.abstol;
68 let reltol = opts.reltol;
69 let maxiters = opts.maxiters;
70 let adaptive = opts.adaptive;
71
72 let t0 = prob.tspan[0];
73 let tf = prob.tspan[1];
74 let f = &prob.f;
75 let state_labels = &prob.state_labels;
76
77 let mut t_out = vec![t0];
78 let mut u_out = vec![copy_state(&prob.u0)];
79 let mut tcur = t0;
80 let mut ucur = copy_state(&prob.u0);
81 let mut dtcur = dt;
82 let mut nsteps = 0usize;
83 let mut consecutive_small = 0usize;
84 let mut check_counter = 0usize;
85
86 let mut eq_result = EquilibriumResult {
87 reached: false,
88 time: 0.0,
89 state: State::new(),
90 max_change: 0.0,
91 steps: 0,
92 reason: "time_exhausted".into(),
93 };
94
95 while tcur < tf && nsteps < maxiters {
96 if tcur + dtcur > tf {
97 dtcur = tf - tcur;
98 }
99
100 let num_stages = solver.c.len();
102 let mut k: Vec<State> = Vec::with_capacity(num_stages);
103 k.push(f(tcur, &ucur));
104
105 for stage in 1..num_stages {
106 let tstage = tcur + solver.c[stage] * dtcur;
107 let mut ustage = copy_state(&ucur);
108 for key in state_labels {
109 for j in 0..stage {
110 let aj = if stage < solver.a.len() && j < solver.a[stage].len() {
111 solver.a[stage][j]
112 } else {
113 0.0
114 };
115 if let (Some(us), Some(kj)) = (ustage.get_mut(key), k[j].get(key)) {
116 *us += dtcur * aj * kj;
117 }
118 }
119 }
120 k.push(f(tstage, &ustage));
121 }
122
123 let mut unext = copy_state(&ucur);
124 for key in state_labels {
125 for j in 0..solver.b.len() {
126 if let (Some(un), Some(kj)) = (unext.get_mut(key), k[j].get(key)) {
127 *un += dtcur * solver.b[j] * kj;
128 }
129 }
130 }
131
132 let mut err = 0.0;
133 if adaptive {
134 for key in state_labels {
135 let mut errest = 0.0;
136 for j in 0..solver.b_hat.len() {
137 if let Some(kj) = k[j].get(key) {
138 errest += dtcur * solver.b_hat[j] * kj;
139 }
140 }
141 let uc = ucur.get(key).copied().unwrap_or(0.0);
142 let un = unext.get(key).copied().unwrap_or(0.0);
143 let mut scale = abstol + reltol * uc.abs().max(un.abs());
144 if scale == 0.0 {
145 scale = abstol;
146 }
147 let val = errest.abs() / scale;
148 if val > err {
149 err = val;
150 }
151 }
152 }
153
154 if !adaptive || err <= 1.0 || dtcur <= dtmin {
155 tcur += dtcur;
156 ucur = unext;
157 t_out.push(tcur);
158 u_out.push(copy_state(&ucur));
159 nsteps += 1;
160
161 check_counter += 1;
163 if tcur >= t0 + eq_opts.min_time
164 && (eq_opts.check_interval == 0 || check_counter >= eq_opts.check_interval)
165 {
166 check_counter = 0;
167 let max_change = compute_max_change(&k[0]);
168
169 if max_change < eq_opts.tolerance {
170 consecutive_small += 1;
171 if consecutive_small >= eq_opts.consecutive_steps {
172 eq_result.reached = true;
173 eq_result.time = tcur;
174 eq_result.state = copy_state(&ucur);
175 eq_result.max_change = max_change;
176 eq_result.steps = nsteps;
177 eq_result.reason = "equilibrium_reached".into();
178 break;
179 }
180 } else {
181 consecutive_small = 0;
182 }
183 }
184
185 if adaptive && err > 0.0 {
186 let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
187 let factor = factor.min(5.0);
188 dtcur = dtmax.min(dtmin.max(dtcur * factor));
189 }
190 } else {
191 let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
192 let factor = factor.max(0.1);
193 dtcur = dtmin.max(dtcur * factor);
194 }
195 }
196
197 if nsteps >= maxiters {
198 eq_result.reason = "max_iterations".into();
199 }
200
201 eq_result.steps = nsteps;
202 if !eq_result.reached {
203 eq_result.time = tcur;
204 eq_result.state = copy_state(&ucur);
205 if !u_out.is_empty() {
206 let du = f(tcur, &ucur);
207 eq_result.max_change = compute_max_change(&du);
208 }
209 }
210
211 let sol = Solution {
212 t: t_out,
213 u: u_out,
214 state_labels: state_labels.clone(),
215 };
216
217 (sol, eq_result)
218}
219
220fn compute_max_change(du: &State) -> f64 {
221 du.values().map(|v| v.abs()).fold(0.0f64, f64::max)
222}
223
224pub fn is_equilibrium(prob: &Problem, state: &State, tolerance: f64) -> bool {
226 let du = (prob.f)(0.0, state);
227 compute_max_change(&du) < tolerance
228}
229
230pub fn find_equilibrium(prob: &Problem) -> (State, bool) {
232 let (_, result) = solve_until_equilibrium(
233 prob,
234 &methods::tsit5(),
235 &Options::default_opts(),
236 &EquilibriumOptions::default_opts(),
237 );
238 (result.state, result.reached)
239}
240
241pub fn find_equilibrium_fast(prob: &Problem) -> (State, bool) {
243 let (sol, result) = solve_until_equilibrium(
244 prob,
245 &methods::tsit5(),
246 &Options::fast(),
247 &EquilibriumOptions::fast(),
248 );
249 if result.reached {
250 (result.state, true)
251 } else {
252 (
253 sol.get_final_state().cloned().unwrap_or_default(),
254 false,
255 )
256 }
257}
258
259pub fn find_equilibrium_accurate(prob: &Problem) -> (State, bool) {
261 let (_, result) = solve_until_equilibrium(
262 prob,
263 &methods::tsit5(),
264 &Options::accurate(),
265 &EquilibriumOptions::strict(),
266 );
267 (result.state, result.reached)
268}
269
270#[derive(Debug, Clone)]
272pub struct OptionPair {
273 pub solver: Options,
274 pub equilibrium: EquilibriumOptions,
275}
276
277impl OptionPair {
278 pub fn game_ai() -> Self {
280 Self {
281 solver: Options::game_ai(),
282 equilibrium: EquilibriumOptions {
283 tolerance: 1e-3,
284 consecutive_steps: 2,
285 min_time: 0.01,
286 check_interval: 3,
287 },
288 }
289 }
290
291 pub fn epidemic() -> Self {
293 Self {
294 solver: Options::epidemic(),
295 equilibrium: EquilibriumOptions::default_opts(),
296 }
297 }
298
299 pub fn workflow() -> Self {
301 Self {
302 solver: Options::workflow(),
303 equilibrium: EquilibriumOptions {
304 tolerance: 1e-4,
305 consecutive_steps: 3,
306 min_time: 0.5,
307 check_interval: 5,
308 },
309 }
310 }
311
312 pub fn long_run() -> Self {
314 Self {
315 solver: Options::long_run(),
316 equilibrium: EquilibriumOptions::strict(),
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use pflow_core::PetriNet;

    /// Shared fixture: `PetriNet::build().sir(999.0, 1.0, 0.0).with_rates(1.0)`,
    /// started from the net's initial state and integrated over `[0, 100]`.
    fn sir_problem() -> Problem {
        let (net, rates) = PetriNet::build().sir(999.0, 1.0, 0.0).with_rates(1.0);
        let initial = net.set_state(None);
        Problem::new(net, initial, [0.0, 100.0], rates)
    }

    #[test]
    fn test_sir_equilibrium() {
        let prob = sir_problem();
        let (final_state, reached) = find_equilibrium(&prob);
        assert!(reached, "SIR should reach equilibrium");

        let (s, i, r) = (final_state["S"], final_state["I"], final_state["R"]);

        // The population total must be preserved by the dynamics.
        let total = s + i + r;
        assert!(
            (total - 1000.0).abs() < 1.0,
            "Total should be conserved: got {}",
            total
        );

        // The outbreak has died out...
        assert!(i < 1.0, "I should be near 0 at equilibrium: got {}", i);

        // ...after most of the population has been infected.
        assert!(r > 900.0, "R should be >900 at equilibrium: got {}", r);
    }

    #[test]
    fn test_is_equilibrium() {
        let prob = sir_problem();

        // The initial outbreak state must not look like an equilibrium.
        assert!(!is_equilibrium(&prob, &prob.u0, 1e-6));

        // Whatever the solver settles on should pass a looser check.
        let (eq_state, reached) = find_equilibrium(&prob);
        if reached {
            assert!(is_equilibrium(&prob, &eq_state, 1e-4));
        }
    }
}