Skip to main content

pflow_solver/
ode.rs

1//! ODE solver core: Problem, Solution, Options, Solve, mass-action kinetics.
2
3use std::collections::HashMap;
4
5use pflow_core::net::PetriNet;
6use pflow_core::State;
7
8use crate::methods::Solver;
9
10/// A function that computes the derivative du/dt given time t and state u.
11pub type ODEFunc = Box<dyn Fn(f64, &State) -> State>;
12
13/// An ODE initial value problem for a Petri net.
14pub struct Problem {
15    pub net: PetriNet,
16    pub u0: State,
17    pub tspan: [f64; 2],
18    pub rates: HashMap<String, f64>,
19    pub f: ODEFunc,
20    pub state_labels: Vec<String>,
21}
22
23impl Problem {
24    /// Creates a new ODE problem from a Petri net.
25    pub fn new(
26        net: PetriNet,
27        initial_state: State,
28        tspan: [f64; 2],
29        rates: HashMap<String, f64>,
30    ) -> Self {
31        let f = build_ode_function(&net, &rates);
32        let state_labels: Vec<String> = initial_state.keys().cloned().collect();
33        Self {
34            net,
35            u0: initial_state,
36            tspan,
37            rates,
38            f,
39            state_labels,
40        }
41    }
42}
43
44/// Constructs the ODE derivative function for a Petri net using mass-action kinetics.
45fn build_ode_function(net: &PetriNet, rates: &HashMap<String, f64>) -> ODEFunc {
46    // Pre-compute structure for the closure
47    let place_labels: Vec<String> = net.places.keys().cloned().collect();
48    let trans_labels: Vec<String> = net.transitions.keys().cloned().collect();
49    let arcs: Vec<(String, String, f64)> = net
50        .arcs
51        .iter()
52        .map(|a| (a.source.clone(), a.target.clone(), a.weight_sum()))
53        .collect();
54    let place_set: std::collections::HashSet<String> = net.places.keys().cloned().collect();
55    let rates = rates.clone();
56
57    Box::new(move |_t: f64, u: &State| -> State {
58        let mut du: State = place_labels.iter().map(|l| (l.clone(), 0.0)).collect();
59
60        for trans_label in &trans_labels {
61            let rate = rates.get(trans_label).copied().unwrap_or(1.0);
62            let mut flux = rate;
63
64            // Compute flux using mass-action kinetics
65            for (source, target, _weight) in &arcs {
66                if target == trans_label && place_set.contains(source) {
67                    let place_state = u.get(source).copied().unwrap_or(0.0);
68                    if place_state <= 0.0 {
69                        flux = 0.0;
70                        break;
71                    }
72                    flux *= place_state;
73                }
74            }
75
76            // Apply flux to connected places
77            if flux > 0.0 {
78                for (source, target, weight) in &arcs {
79                    if target == trans_label && place_set.contains(source) {
80                        // Input arc: consume tokens
81                        if let Some(v) = du.get_mut(source) {
82                            *v -= flux * weight;
83                        }
84                    } else if source == trans_label && place_set.contains(target) {
85                        // Output arc: produce tokens
86                        if let Some(v) = du.get_mut(target) {
87                            *v += flux * weight;
88                        }
89                    }
90                }
91            }
92        }
93        du
94    })
95}
96
97/// The solution to an ODE problem.
98pub struct Solution {
99    pub t: Vec<f64>,
100    pub u: Vec<State>,
101    pub state_labels: Vec<String>,
102}
103
104impl Solution {
105    /// Extracts the time series for a specific state variable by label.
106    pub fn get_variable(&self, label: &str) -> Vec<f64> {
107        self.u
108            .iter()
109            .map(|s| s.get(label).copied().unwrap_or(0.0))
110            .collect()
111    }
112
113    /// Returns the final state of the system.
114    pub fn get_final_state(&self) -> Option<&State> {
115        self.u.last()
116    }
117
118    /// Returns the state at a specific time point index.
119    pub fn get_state(&self, i: usize) -> Option<&State> {
120        self.u.get(i)
121    }
122}
123
/// Solver configuration parameters.
///
/// Start from one of the presets below, or from [`Options::default`] (which equals
/// [`Options::default_opts`]), and override individual fields as needed, e.g.
/// `Options { maxiters: 10_000, ..Options::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct Options {
    /// Initial step size. With `adaptive == false`, every step uses this size
    /// except the last one, which is shortened to land on the end time.
    pub dt: f64,
    /// Lower bound for the adaptive step size. A step at or below this size is
    /// accepted even if its error estimate exceeds the tolerance.
    pub dtmin: f64,
    /// Upper bound applied to the adaptive step size after an accepted step.
    pub dtmax: f64,
    /// Absolute error tolerance used by the adaptive controller.
    pub abstol: f64,
    /// Relative error tolerance used by the adaptive controller.
    pub reltol: f64,
    /// Maximum number of accepted steps before integration stops.
    pub maxiters: usize,
    /// Whether to adapt the step size from the solver's error estimate.
    pub adaptive: bool,
}

impl Default for Options {
    /// Same values as [`Options::default_opts`].
    fn default() -> Self {
        Self::default_opts()
    }
}

impl Options {
    /// Default solver options — balanced for most problems.
    pub fn default_opts() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.1,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Options intended to match the pflow.xyz JavaScript solver
    /// (identical to [`Options::default_opts`] except `dtmax = 1.0`).
    pub fn js_parity() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 1.0,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Fast options: loose tolerances, large steps and a low step cap,
    /// trading accuracy for speed.
    pub fn fast() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 1_000,
            adaptive: true,
        }
    }

    /// Accurate options: tight tolerances and small steps.
    pub fn accurate() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-8,
            dtmax: 0.1,
            abstol: 1e-9,
            reltol: 1e-6,
            maxiters: 1_000_000,
            adaptive: true,
        }
    }

    /// Options for stiff ODE systems: very small minimum and maximum step sizes.
    pub fn stiff() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-10,
            dtmax: 0.01,
            abstol: 1e-8,
            reltol: 1e-5,
            maxiters: 500_000,
            adaptive: true,
        }
    }

    /// Game AI options: fast move evaluation with a small step cap.
    pub fn game_ai() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-3,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 500,
            adaptive: true,
        }
    }

    /// Epidemic/population modeling options.
    pub fn epidemic() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.5,
            abstol: 1e-6,
            reltol: 1e-4,
            maxiters: 200_000,
            adaptive: true,
        }
    }

    /// Workflow/process simulation options.
    pub fn workflow() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-4,
            reltol: 1e-3,
            maxiters: 50_000,
            adaptive: true,
        }
    }

    /// Long-run simulation options.
    pub fn long_run() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-5,
            reltol: 1e-3,
            maxiters: 500_000,
            adaptive: true,
        }
    }
}
254
255/// Copies a state map.
256pub fn copy_state(s: &State) -> State {
257    s.clone()
258}
259
260/// Integrates the ODE problem using the given solver and options.
261pub fn solve(prob: &Problem, solver: &Solver, opts: &Options) -> Solution {
262    let dt = opts.dt;
263    let dtmin = opts.dtmin;
264    let dtmax = opts.dtmax;
265    let abstol = opts.abstol;
266    let reltol = opts.reltol;
267    let maxiters = opts.maxiters;
268    let adaptive = opts.adaptive;
269
270    let t0 = prob.tspan[0];
271    let tf = prob.tspan[1];
272    let f = &prob.f;
273    let state_labels = &prob.state_labels;
274
275    let mut t_out = vec![t0];
276    let mut u_out = vec![copy_state(&prob.u0)];
277    let mut tcur = t0;
278    let mut ucur = copy_state(&prob.u0);
279    let mut dtcur = dt;
280    let mut nsteps = 0usize;
281
282    while tcur < tf && nsteps < maxiters {
283        // Don't overshoot
284        if tcur + dtcur > tf {
285            dtcur = tf - tcur;
286        }
287
288        // Compute Runge-Kutta stages
289        let num_stages = solver.c.len();
290        let mut k: Vec<State> = Vec::with_capacity(num_stages);
291        k.push(f(tcur, &ucur));
292
293        for stage in 1..num_stages {
294            let tstage = tcur + solver.c[stage] * dtcur;
295            let mut ustage = copy_state(&ucur);
296            for key in state_labels {
297                for j in 0..stage {
298                    let aj = if stage < solver.a.len() && j < solver.a[stage].len() {
299                        solver.a[stage][j]
300                    } else {
301                        0.0
302                    };
303                    if let (Some(us), Some(kj)) = (ustage.get_mut(key), k[j].get(key)) {
304                        *us += dtcur * aj * kj;
305                    }
306                }
307            }
308            k.push(f(tstage, &ustage));
309        }
310
311        // Compute solution at next step
312        let mut unext = copy_state(&ucur);
313        for key in state_labels {
314            for j in 0..solver.b.len() {
315                if let (Some(un), Some(kj)) = (unext.get_mut(key), k[j].get(key)) {
316                    *un += dtcur * solver.b[j] * kj;
317                }
318            }
319        }
320
321        // Compute error estimate
322        let mut err = 0.0;
323        if adaptive {
324            for key in state_labels {
325                let mut errest = 0.0;
326                for j in 0..solver.b_hat.len() {
327                    if let Some(kj) = k[j].get(key) {
328                        errest += dtcur * solver.b_hat[j] * kj;
329                    }
330                }
331                let uc = ucur.get(key).copied().unwrap_or(0.0);
332                let un = unext.get(key).copied().unwrap_or(0.0);
333                let mut scale = abstol + reltol * uc.abs().max(un.abs());
334                if scale == 0.0 {
335                    scale = abstol;
336                }
337                let val = errest.abs() / scale;
338                if val > err {
339                    err = val;
340                }
341            }
342        }
343
344        // Accept or reject step
345        if !adaptive || err <= 1.0 || dtcur <= dtmin {
346            tcur += dtcur;
347            ucur = unext;
348            t_out.push(tcur);
349            u_out.push(copy_state(&ucur));
350            nsteps += 1;
351
352            if adaptive && err > 0.0 {
353                let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
354                let factor = factor.min(5.0);
355                dtcur = dtmax.min(dtmin.max(dtcur * factor));
356            }
357        } else {
358            let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
359            let factor = factor.max(0.1);
360            dtcur = dtmin.max(dtcur * factor);
361        }
362    }
363
364    Solution {
365        t: t_out,
366        u: u_out,
367        state_labels: state_labels.clone(),
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::methods;

    /// A single transition `t1` moving tokens from `A` to `B` must (approximately)
    /// conserve the total token count `A + B = 10`.
    #[test]
    fn test_simple_decay() {
        let net = PetriNet::build()
            .place("A", 10.0)
            .place("B", 0.0)
            .transition("t1")
            .arc("A", "t1", 1.0)
            .arc("t1", "B", 1.0)
            .done();

        let initial = net.set_state(None);
        let rates = net.set_rates(None);
        let tspan = [0.0, 10.0];
        let problem = Problem::new(net, initial, tspan, rates);
        let solution = solve(&problem, &methods::tsit5(), &Options::default_opts());

        let last = solution
            .get_final_state()
            .expect("a solution always contains at least the initial state");
        let conserved = last["A"] + last["B"];
        assert!(
            (conserved - 10.0).abs() < 0.1,
            "A + B drifted to {conserved}"
        );
    }
}