Skip to main content

pflow_solver/
equilibrium.rs

1//! Equilibrium detection for ODE solutions.
2
3use pflow_core::State;
4
5use crate::methods;
6use crate::ode::{copy_state, Options, Problem, Solution};
7
/// Configuration for equilibrium detection.
///
/// Consumed by [`solve_until_equilibrium`]. Once integration time reaches
/// `t0 + min_time`, the derivative is inspected every `check_interval`
/// accepted steps. Equilibrium is declared after `consecutive_steps`
/// consecutive inspections in which the largest absolute derivative component
/// is below `tolerance`.
#[derive(Debug, Clone, PartialEq)]
pub struct EquilibriumOptions {
    /// Exclusive upper bound on the largest absolute derivative component for
    /// a check to count as "at rest".
    pub tolerance: f64,
    /// Number of consecutive passing *checks* (not integration steps) needed
    /// before equilibrium is reported.
    pub consecutive_steps: usize,
    /// Time after the start of the span before any check is performed.
    pub min_time: f64,
    /// Accepted steps between checks; `0` checks after every step.
    pub check_interval: usize,
}

impl EquilibriumOptions {
    /// Balanced defaults.
    ///
    /// `tolerance = 1e-6`, 5 consecutive checks, `min_time = 0.1`, and a
    /// check every 10 steps.
    pub fn default_opts() -> Self {
        Self {
            tolerance: 1e-6,
            consecutive_steps: 5,
            min_time: 0.1,
            check_interval: 10,
        }
    }

    /// Looser, quicker detection.
    ///
    /// `tolerance = 1e-4`, 3 consecutive checks, `min_time = 0.01`, and a
    /// check every 5 steps.
    pub fn fast() -> Self {
        Self {
            tolerance: 1e-4,
            consecutive_steps: 3,
            min_time: 0.01,
            check_interval: 5,
        }
    }

    /// Conservative detection.
    ///
    /// `tolerance = 1e-9`, 10 consecutive checks, `min_time = 1.0`, and a
    /// check after every step.
    pub fn strict() -> Self {
        Self {
            tolerance: 1e-9,
            consecutive_steps: 10,
            min_time: 1.0,
            check_interval: 1,
        }
    }
}

impl Default for EquilibriumOptions {
    /// Same as [`EquilibriumOptions::default_opts`].
    fn default() -> Self {
        Self::default_opts()
    }
}
45
46/// Result of equilibrium detection.
47#[derive(Debug, Clone)]
48pub struct EquilibriumResult {
49    pub reached: bool,
50    pub time: f64,
51    pub state: State,
52    pub max_change: f64,
53    pub steps: usize,
54    pub reason: String,
55}
56
57/// Integrates until equilibrium or time span exhausted.
58pub fn solve_until_equilibrium(
59    prob: &Problem,
60    solver: &methods::Solver,
61    opts: &Options,
62    eq_opts: &EquilibriumOptions,
63) -> (Solution, EquilibriumResult) {
64    let dt = opts.dt;
65    let dtmin = opts.dtmin;
66    let dtmax = opts.dtmax;
67    let abstol = opts.abstol;
68    let reltol = opts.reltol;
69    let maxiters = opts.maxiters;
70    let adaptive = opts.adaptive;
71
72    let t0 = prob.tspan[0];
73    let tf = prob.tspan[1];
74    let f = &prob.f;
75    let state_labels = &prob.state_labels;
76
77    let mut t_out = vec![t0];
78    let mut u_out = vec![copy_state(&prob.u0)];
79    let mut tcur = t0;
80    let mut ucur = copy_state(&prob.u0);
81    let mut dtcur = dt;
82    let mut nsteps = 0usize;
83    let mut consecutive_small = 0usize;
84    let mut check_counter = 0usize;
85
86    let mut eq_result = EquilibriumResult {
87        reached: false,
88        time: 0.0,
89        state: State::new(),
90        max_change: 0.0,
91        steps: 0,
92        reason: "time_exhausted".into(),
93    };
94
95    while tcur < tf && nsteps < maxiters {
96        if tcur + dtcur > tf {
97            dtcur = tf - tcur;
98        }
99
100        // Compute RK stages
101        let num_stages = solver.c.len();
102        let mut k: Vec<State> = Vec::with_capacity(num_stages);
103        k.push(f(tcur, &ucur));
104
105        for stage in 1..num_stages {
106            let tstage = tcur + solver.c[stage] * dtcur;
107            let mut ustage = copy_state(&ucur);
108            for key in state_labels {
109                for j in 0..stage {
110                    let aj = if stage < solver.a.len() && j < solver.a[stage].len() {
111                        solver.a[stage][j]
112                    } else {
113                        0.0
114                    };
115                    if let (Some(us), Some(kj)) = (ustage.get_mut(key), k[j].get(key)) {
116                        *us += dtcur * aj * kj;
117                    }
118                }
119            }
120            k.push(f(tstage, &ustage));
121        }
122
123        let mut unext = copy_state(&ucur);
124        for key in state_labels {
125            for j in 0..solver.b.len() {
126                if let (Some(un), Some(kj)) = (unext.get_mut(key), k[j].get(key)) {
127                    *un += dtcur * solver.b[j] * kj;
128                }
129            }
130        }
131
132        let mut err = 0.0;
133        if adaptive {
134            for key in state_labels {
135                let mut errest = 0.0;
136                for j in 0..solver.b_hat.len() {
137                    if let Some(kj) = k[j].get(key) {
138                        errest += dtcur * solver.b_hat[j] * kj;
139                    }
140                }
141                let uc = ucur.get(key).copied().unwrap_or(0.0);
142                let un = unext.get(key).copied().unwrap_or(0.0);
143                let mut scale = abstol + reltol * uc.abs().max(un.abs());
144                if scale == 0.0 {
145                    scale = abstol;
146                }
147                let val = errest.abs() / scale;
148                if val > err {
149                    err = val;
150                }
151            }
152        }
153
154        if !adaptive || err <= 1.0 || dtcur <= dtmin {
155            tcur += dtcur;
156            ucur = unext;
157            t_out.push(tcur);
158            u_out.push(copy_state(&ucur));
159            nsteps += 1;
160
161            // Check for equilibrium
162            check_counter += 1;
163            if tcur >= t0 + eq_opts.min_time
164                && (eq_opts.check_interval == 0 || check_counter >= eq_opts.check_interval)
165            {
166                check_counter = 0;
167                let max_change = compute_max_change(&k[0]);
168
169                if max_change < eq_opts.tolerance {
170                    consecutive_small += 1;
171                    if consecutive_small >= eq_opts.consecutive_steps {
172                        eq_result.reached = true;
173                        eq_result.time = tcur;
174                        eq_result.state = copy_state(&ucur);
175                        eq_result.max_change = max_change;
176                        eq_result.steps = nsteps;
177                        eq_result.reason = "equilibrium_reached".into();
178                        break;
179                    }
180                } else {
181                    consecutive_small = 0;
182                }
183            }
184
185            if adaptive && err > 0.0 {
186                let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
187                let factor = factor.min(5.0);
188                dtcur = dtmax.min(dtmin.max(dtcur * factor));
189            }
190        } else {
191            let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
192            let factor = factor.max(0.1);
193            dtcur = dtmin.max(dtcur * factor);
194        }
195    }
196
197    if nsteps >= maxiters {
198        eq_result.reason = "max_iterations".into();
199    }
200
201    eq_result.steps = nsteps;
202    if !eq_result.reached {
203        eq_result.time = tcur;
204        eq_result.state = copy_state(&ucur);
205        if !u_out.is_empty() {
206            let du = f(tcur, &ucur);
207            eq_result.max_change = compute_max_change(&du);
208        }
209    }
210
211    let sol = Solution {
212        t: t_out,
213        u: u_out,
214        state_labels: state_labels.clone(),
215    };
216
217    (sol, eq_result)
218}
219
220fn compute_max_change(du: &State) -> f64 {
221    du.values().map(|v| v.abs()).fold(0.0f64, f64::max)
222}
223
224/// Checks if a state is at equilibrium for the given problem.
225pub fn is_equilibrium(prob: &Problem, state: &State, tolerance: f64) -> bool {
226    let du = (prob.f)(0.0, state);
227    compute_max_change(&du) < tolerance
228}
229
230/// Solves until equilibrium and returns just the final state.
231pub fn find_equilibrium(prob: &Problem) -> (State, bool) {
232    let (_, result) = solve_until_equilibrium(
233        prob,
234        &methods::tsit5(),
235        &Options::default_opts(),
236        &EquilibriumOptions::default_opts(),
237    );
238    (result.state, result.reached)
239}
240
241/// Fast equilibrium detection with aggressive settings.
242pub fn find_equilibrium_fast(prob: &Problem) -> (State, bool) {
243    let (sol, result) = solve_until_equilibrium(
244        prob,
245        &methods::tsit5(),
246        &Options::fast(),
247        &EquilibriumOptions::fast(),
248    );
249    if result.reached {
250        (result.state, true)
251    } else {
252        (
253            sol.get_final_state().cloned().unwrap_or_default(),
254            false,
255        )
256    }
257}
258
259/// Strict equilibrium detection with high confidence.
260pub fn find_equilibrium_accurate(prob: &Problem) -> (State, bool) {
261    let (_, result) = solve_until_equilibrium(
262        prob,
263        &methods::tsit5(),
264        &Options::accurate(),
265        &EquilibriumOptions::strict(),
266    );
267    (result.state, result.reached)
268}
269
270/// Combines solver and equilibrium options for specific use cases.
271#[derive(Debug, Clone)]
272pub struct OptionPair {
273    pub solver: Options,
274    pub equilibrium: EquilibriumOptions,
275}
276
277impl OptionPair {
278    /// Game AI options: fast evaluation with loose equilibrium detection.
279    pub fn game_ai() -> Self {
280        Self {
281            solver: Options::game_ai(),
282            equilibrium: EquilibriumOptions {
283                tolerance: 1e-3,
284                consecutive_steps: 2,
285                min_time: 0.01,
286                check_interval: 3,
287            },
288        }
289    }
290
291    /// Epidemic modeling options.
292    pub fn epidemic() -> Self {
293        Self {
294            solver: Options::epidemic(),
295            equilibrium: EquilibriumOptions::default_opts(),
296        }
297    }
298
299    /// Workflow/process simulation options.
300    pub fn workflow() -> Self {
301        Self {
302            solver: Options::workflow(),
303            equilibrium: EquilibriumOptions {
304                tolerance: 1e-4,
305                consecutive_steps: 3,
306                min_time: 0.5,
307                check_interval: 5,
308            },
309        }
310    }
311
312    /// Extended equilibrium analysis options.
313    pub fn long_run() -> Self {
314        Self {
315            solver: Options::long_run(),
316            equilibrium: EquilibriumOptions::strict(),
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use pflow_core::PetriNet;

    /// SIR problem shared by these tests.
    ///
    /// Built with `sir(999.0, 1.0, 0.0)` (1000 individuals in total), all rates
    /// set to 1.0, and integrated over `t ∈ [0, 100]`.
    fn sir_problem() -> Problem {
        let (net, rates) = PetriNet::build().sir(999.0, 1.0, 0.0).with_rates(1.0);
        let state = net.set_state(None);
        Problem::new(net, state, [0.0, 100.0], rates)
    }

    #[test]
    fn test_sir_equilibrium() {
        let prob = sir_problem();
        let (final_state, reached) = find_equilibrium(&prob);

        assert!(reached, "SIR should reach equilibrium");

        // Conservation: S + I + R = 1000
        let total = final_state["S"] + final_state["I"] + final_state["R"];
        assert!(
            (total - 1000.0).abs() < 1.0,
            "Total should be conserved: got {}",
            total
        );

        // At equilibrium, I should be near 0
        assert!(
            final_state["I"] < 1.0,
            "I should be near 0 at equilibrium: got {}",
            final_state["I"]
        );

        // R should have most of the population
        assert!(
            final_state["R"] > 900.0,
            "R should be >900 at equilibrium: got {}",
            final_state["R"]
        );
    }

    #[test]
    fn test_is_equilibrium() {
        let prob = sir_problem();

        // Initial state should NOT be at equilibrium
        assert!(!is_equilibrium(&prob, &prob.u0, 1e-6));

        // Same problem and settings as `test_sir_equilibrium`, so equilibrium
        // must be reached. Asserting it keeps this test from passing vacuously.
        let (eq_state, reached) = find_equilibrium(&prob);
        assert!(reached, "SIR should reach equilibrium");
        assert!(is_equilibrium(&prob, &eq_state, 1e-4));
    }

    #[test]
    fn test_solve_until_equilibrium_trajectory() {
        let prob = sir_problem();
        let (sol, result) = solve_until_equilibrium(
            &prob,
            &methods::tsit5(),
            &Options::default_opts(),
            &EquilibriumOptions::default_opts(),
        );

        assert!(result.reached, "SIR should reach equilibrium");
        assert_eq!(result.reason, "equilibrium_reached");
        // One recorded point per accepted step, plus the initial condition.
        assert_eq!(sol.t.len(), result.steps + 1);
        assert_eq!(sol.u.len(), sol.t.len());
        // Equilibrium is declared right after recording the triggering step.
        assert_eq!(sol.t.last().copied(), Some(result.time));
    }
}
375}