Skip to main content

pflow_solver/
ode.rs

1//! ODE solver core: Problem, Solution, Options, Solve, mass-action kinetics.
2
3use std::collections::HashMap;
4
5use pflow_core::net::PetriNet;
6use pflow_core::State;
7
8use crate::methods::Solver;
9
/// A function that computes the derivative du/dt given time t and state u.
///
/// The returned map holds the derivative for each labeled state variable.
/// The mass-action function built in this module ignores `t`.
pub type ODEFunc = Box<dyn Fn(f64, &State) -> State>;

// Internal vectorized ODE function: (t, u) -> du using dense arrays.
// Slot `i` of both `u` and the returned vector corresponds to
// `Problem::state_labels[i]`.
type VecF = Box<dyn Fn(f64, &[f64]) -> Vec<f64>>;
15
/// An ODE initial value problem for a Petri net.
///
/// Holds a labeled (`State`-based) derivative function, exposed as `f`, and
/// dense-vector internals that [`solve`] integrates for speed. Both are
/// derived from `net` and `rates` by [`Problem::new`].
pub struct Problem {
    /// The Petri net whose transitions define the dynamics.
    pub net: PetriNet,
    /// Initial state, keyed by state label.
    pub u0: State,
    /// Integration interval `[t0, tf]`.
    pub tspan: [f64; 2],
    /// Rate constant per transition label; transitions missing here use a
    /// rate of `1.0` in the derived ODE functions.
    pub rates: HashMap<String, f64>,
    /// Labeled derivative function built from `net` and `rates`.
    pub f: ODEFunc,
    /// State labels, in the order used by the dense internals and copied
    /// into the [`Solution`] returned by [`solve`].
    pub state_labels: Vec<String>,
    // Vectorized internals for fast solve()
    // Maps each label to its slot in `state_labels`. It is only consumed while
    // building `vec_f` and appears unread afterwards, hence the allow.
    #[allow(dead_code)]
    state_index: HashMap<String, usize>,
    // Initial state as a dense vector, ordered like `state_labels`.
    vec_u0: Vec<f64>,
    // Dense derivative function evaluated by `solve()`.
    vec_f: VecF,
}
30
31impl Problem {
32    /// Creates a new ODE problem from a Petri net.
33    pub fn new(
34        net: PetriNet,
35        initial_state: State,
36        tspan: [f64; 2],
37        rates: HashMap<String, f64>,
38    ) -> Self {
39        let f = build_ode_function(&net, &rates);
40        let state_labels: Vec<String> = initial_state.keys().cloned().collect();
41        let state_index: HashMap<String, usize> = state_labels
42            .iter()
43            .enumerate()
44            .map(|(i, label)| (label.clone(), i))
45            .collect();
46        let vec_u0: Vec<f64> = state_labels
47            .iter()
48            .map(|label| initial_state.get(label).copied().unwrap_or(0.0))
49            .collect();
50        let n_places = state_labels.len();
51        let vec_f = build_vec_ode_function(&net, &rates, &state_index, n_places);
52        Self {
53            net,
54            u0: initial_state,
55            tspan,
56            rates,
57            f,
58            state_labels,
59            state_index,
60            vec_u0,
61            vec_f,
62        }
63    }
64}
65
66/// Constructs the ODE derivative function for a Petri net using mass-action kinetics.
67fn build_ode_function(net: &PetriNet, rates: &HashMap<String, f64>) -> ODEFunc {
68    // Pre-compute structure for the closure
69    let place_labels: Vec<String> = net.places.keys().cloned().collect();
70    let trans_labels: Vec<String> = net.transitions.keys().cloned().collect();
71    let arcs: Vec<(String, String, f64)> = net
72        .arcs
73        .iter()
74        .map(|a| (a.source.clone(), a.target.clone(), a.weight_sum()))
75        .collect();
76    let place_set: std::collections::HashSet<String> = net.places.keys().cloned().collect();
77    let rates = rates.clone();
78
79    Box::new(move |_t: f64, u: &State| -> State {
80        let mut du: State = place_labels.iter().map(|l| (l.clone(), 0.0)).collect();
81
82        for trans_label in &trans_labels {
83            let rate = rates.get(trans_label).copied().unwrap_or(1.0);
84            let mut flux = rate;
85
86            // Compute flux using mass-action kinetics
87            for (source, target, _weight) in &arcs {
88                if target == trans_label && place_set.contains(source) {
89                    let place_state = u.get(source).copied().unwrap_or(0.0);
90                    if place_state <= 0.0 {
91                        flux = 0.0;
92                        break;
93                    }
94                    flux *= place_state;
95                }
96            }
97
98            // Apply flux to connected places
99            if flux > 0.0 {
100                for (source, target, weight) in &arcs {
101                    if target == trans_label && place_set.contains(source) {
102                        // Input arc: consume tokens
103                        if let Some(v) = du.get_mut(source) {
104                            *v -= flux * weight;
105                        }
106                    } else if source == trans_label && place_set.contains(target) {
107                        // Output arc: produce tokens
108                        if let Some(v) = du.get_mut(target) {
109                            *v += flux * weight;
110                        }
111                    }
112                }
113            }
114        }
115        du
116    })
117}
118
119/// Constructs a vectorized ODE derivative function with pre-indexed arcs.
120///
121/// This replaces HashMap lookups with array indexing and pre-groups arcs
122/// by transition, reducing per-call cost from O(T*A) to O(A).
123fn build_vec_ode_function(
124    net: &PetriNet,
125    rates: &HashMap<String, f64>,
126    state_index: &HashMap<String, usize>,
127    n_places: usize,
128) -> VecF {
129    // Pre-group arcs by transition: O(A) construction
130    let mut input_map: HashMap<&str, Vec<(usize, f64)>> = HashMap::new();
131    let mut output_map: HashMap<&str, Vec<(usize, f64)>> = HashMap::new();
132
133    for arc in &net.arcs {
134        let w = arc.weight_sum();
135        if net.transitions.contains_key(&arc.target) {
136            if let Some(&idx) = state_index.get(&arc.source) {
137                input_map
138                    .entry(arc.target.as_str())
139                    .or_default()
140                    .push((idx, w));
141            }
142        }
143        if net.transitions.contains_key(&arc.source) {
144            if let Some(&idx) = state_index.get(&arc.target) {
145                output_map
146                    .entry(arc.source.as_str())
147                    .or_default()
148                    .push((idx, w));
149            }
150        }
151    }
152
153    // Build compact transition table: (rate, inputs, outputs)
154    let transitions: Vec<(f64, Vec<(usize, f64)>, Vec<(usize, f64)>)> = net
155        .transitions
156        .keys()
157        .map(|label| {
158            let rate = rates.get(label).copied().unwrap_or(1.0);
159            let inputs = input_map.remove(label.as_str()).unwrap_or_default();
160            let outputs = output_map.remove(label.as_str()).unwrap_or_default();
161            (rate, inputs, outputs)
162        })
163        .collect();
164
165    Box::new(move |_t: f64, u: &[f64]| -> Vec<f64> {
166        let mut du = vec![0.0; n_places];
167
168        for (rate, inputs, outputs) in &transitions {
169            let mut flux = *rate;
170
171            // Mass-action kinetics: flux = rate * product(input tokens)
172            for &(idx, _w) in inputs {
173                let v = u[idx];
174                if v <= 0.0 {
175                    flux = 0.0;
176                    break;
177                }
178                flux *= v;
179            }
180
181            if flux > 0.0 {
182                for &(idx, w) in inputs {
183                    du[idx] -= flux * w;
184                }
185                for &(idx, w) in outputs {
186                    du[idx] += flux * w;
187                }
188            }
189        }
190
191        du
192    })
193}
194
195/// The solution to an ODE problem.
196pub struct Solution {
197    pub t: Vec<f64>,
198    pub u: Vec<State>,
199    pub state_labels: Vec<String>,
200}
201
202impl Solution {
203    /// Extracts the time series for a specific state variable by label.
204    pub fn get_variable(&self, label: &str) -> Vec<f64> {
205        self.u
206            .iter()
207            .map(|s| s.get(label).copied().unwrap_or(0.0))
208            .collect()
209    }
210
211    /// Returns the final state of the system.
212    pub fn get_final_state(&self) -> Option<&State> {
213        self.u.last()
214    }
215
216    /// Returns the state at a specific time point index.
217    pub fn get_state(&self, i: usize) -> Option<&State> {
218        self.u.get(i)
219    }
220}
221
/// Solver configuration parameters.
///
/// Start from one of the presets below and adjust individual fields with
/// struct-update syntax as needed.
#[derive(Debug, Clone)]
pub struct Options {
    /// Initial step size; also the fixed step when `adaptive` is false.
    pub dt: f64,
    /// Smallest adaptive step; a step at or below this size is accepted
    /// regardless of its error estimate.
    pub dtmin: f64,
    /// Upper bound applied when the adaptive controller rescales the step.
    pub dtmax: f64,
    /// Absolute error tolerance.
    pub abstol: f64,
    /// Relative error tolerance.
    pub reltol: f64,
    /// Maximum number of accepted steps.
    pub maxiters: usize,
    /// Whether to adapt the step size from the embedded error estimate.
    pub adaptive: bool,
}

impl Options {
    /// Default solver options — balanced for most problems.
    pub fn default_opts() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.1,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Options that match the pflow.xyz JavaScript solver.
    ///
    /// Identical to [`Options::default_opts`] except for a larger `dtmax`.
    pub fn js_parity() -> Self {
        Self {
            dtmax: 1.0,
            ..Self::default_opts()
        }
    }

    /// Fast options: speed over accuracy (~10x faster).
    pub fn fast() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 1_000,
            adaptive: true,
        }
    }

    /// Accurate options: high precision.
    pub fn accurate() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-8,
            dtmax: 0.1,
            abstol: 1e-9,
            reltol: 1e-6,
            maxiters: 1_000_000,
            adaptive: true,
        }
    }

    /// Options for stiff ODE systems.
    pub fn stiff() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-10,
            dtmax: 0.01,
            abstol: 1e-8,
            reltol: 1e-5,
            maxiters: 500_000,
            adaptive: true,
        }
    }

    /// Game AI options: fast move evaluation.
    ///
    /// Like [`Options::fast`], but with a coarser `dtmin` and a tighter step budget.
    pub fn game_ai() -> Self {
        Self {
            dtmin: 1e-3,
            maxiters: 500,
            ..Self::fast()
        }
    }

    /// Epidemic/population modeling options.
    ///
    /// Like [`Options::default_opts`], with a larger `dtmax`, a tighter
    /// `reltol`, and a bigger step budget.
    pub fn epidemic() -> Self {
        Self {
            dtmax: 0.5,
            reltol: 1e-4,
            maxiters: 200_000,
            ..Self::default_opts()
        }
    }

    /// Workflow/process simulation options.
    pub fn workflow() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-4,
            reltol: 1e-3,
            maxiters: 50_000,
            adaptive: true,
        }
    }

    /// Long-run simulation options.
    ///
    /// Like [`Options::workflow`], with a tighter `abstol` and a bigger step budget.
    pub fn long_run() -> Self {
        Self {
            abstol: 1e-5,
            maxiters: 500_000,
            ..Self::workflow()
        }
    }
}
352
353/// Copies a state map.
354pub fn copy_state(s: &State) -> State {
355    s.clone()
356}
357
358/// Converts a dense vector back to a labeled State map.
359fn vec_to_state(v: &[f64], labels: &[String]) -> State {
360    labels
361        .iter()
362        .enumerate()
363        .map(|(i, label)| (label.clone(), v[i]))
364        .collect()
365}
366
367/// Integrates the ODE problem using the given solver and options.
368///
369/// Internally uses vectorized (dense array) state representation for performance.
370pub fn solve(prob: &Problem, solver: &Solver, opts: &Options) -> Solution {
371    let dt = opts.dt;
372    let dtmin = opts.dtmin;
373    let dtmax = opts.dtmax;
374    let abstol = opts.abstol;
375    let reltol = opts.reltol;
376    let maxiters = opts.maxiters;
377    let adaptive = opts.adaptive;
378
379    let t0 = prob.tspan[0];
380    let tf = prob.tspan[1];
381    let f = &prob.vec_f;
382    let n = prob.vec_u0.len();
383
384    let mut t_out = vec![t0];
385    let mut u_out: Vec<Vec<f64>> = vec![prob.vec_u0.clone()];
386    let mut tcur = t0;
387    let mut ucur = prob.vec_u0.clone();
388    let mut dtcur = dt;
389    let mut nsteps = 0usize;
390
391    while tcur < tf && nsteps < maxiters {
392        // Don't overshoot
393        if tcur + dtcur > tf {
394            dtcur = tf - tcur;
395        }
396
397        // Compute Runge-Kutta stages
398        let num_stages = solver.c.len();
399        let mut k: Vec<Vec<f64>> = Vec::with_capacity(num_stages);
400        k.push(f(tcur, &ucur));
401
402        for stage in 1..num_stages {
403            let tstage = tcur + solver.c[stage] * dtcur;
404            let mut ustage = ucur.clone();
405            for j in 0..stage {
406                let aj = if stage < solver.a.len() && j < solver.a[stage].len() {
407                    solver.a[stage][j]
408                } else {
409                    0.0
410                };
411                if aj != 0.0 {
412                    let scale = dtcur * aj;
413                    for i in 0..n {
414                        ustage[i] += scale * k[j][i];
415                    }
416                }
417            }
418            k.push(f(tstage, &ustage));
419        }
420
421        // Compute solution at next step
422        let mut unext = ucur.clone();
423        for j in 0..solver.b.len() {
424            if solver.b[j] != 0.0 {
425                let scale = dtcur * solver.b[j];
426                for i in 0..n {
427                    unext[i] += scale * k[j][i];
428                }
429            }
430        }
431
432        // Compute error estimate
433        let mut err = 0.0;
434        if adaptive {
435            for i in 0..n {
436                let mut errest = 0.0;
437                for j in 0..solver.b_hat.len() {
438                    errest += dtcur * solver.b_hat[j] * k[j][i];
439                }
440                let uc = ucur[i];
441                let un = unext[i];
442                let mut scale = abstol + reltol * uc.abs().max(un.abs());
443                if scale == 0.0 {
444                    scale = abstol;
445                }
446                let val = errest.abs() / scale;
447                if val > err {
448                    err = val;
449                }
450            }
451        }
452
453        // Accept or reject step
454        if !adaptive || err <= 1.0 || dtcur <= dtmin {
455            tcur += dtcur;
456            ucur = unext;
457            t_out.push(tcur);
458            u_out.push(ucur.clone());
459            nsteps += 1;
460
461            if adaptive && err > 0.0 {
462                let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
463                let factor = factor.min(5.0);
464                dtcur = dtmax.min(dtmin.max(dtcur * factor));
465            }
466        } else {
467            let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
468            let factor = factor.max(0.1);
469            dtcur = dtmin.max(dtcur * factor);
470        }
471    }
472
473    // Convert dense trajectory to State maps for backward compatibility
474    let state_u: Vec<State> = u_out
475        .iter()
476        .map(|v| vec_to_state(v, &prob.state_labels))
477        .collect();
478
479    Solution {
480        t: t_out,
481        u: state_u,
482        state_labels: prob.state_labels.clone(),
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::methods;

    /// Builds `A -> t1 -> B` with 10 tokens in `A` and none in `B`.
    fn decay_net() -> PetriNet {
        PetriNet::build()
            .place("A", 10.0)
            .place("B", 0.0)
            .transition("t1")
            .arc("A", "t1", 1.0)
            .arc("t1", "B", 1.0)
            .done()
    }

    /// Gives `t1` unit rate, so that `A(t) = 10 * exp(-t)` exactly.
    fn unit_rates() -> HashMap<String, f64> {
        let mut rates = HashMap::new();
        rates.insert("t1".to_string(), 1.0);
        rates
    }

    #[test]
    fn test_simple_decay() {
        // A -> t1 -> B, should transfer tokens from A to B
        let net = decay_net();
        let state = net.set_state(None);
        let prob = Problem::new(net, state, [0.0, 10.0], unit_rates());
        let sol = solve(&prob, &methods::tsit5(), &Options::default_opts());

        let final_state = sol.get_final_state().unwrap();
        let total = final_state["A"] + final_state["B"];
        // Conservation: A + B should be approximately 10
        assert!((total - 10.0).abs() < 0.1);
        // Tokens actually moved: the exact value is A(10) = 10 * e^-10 ~ 4.5e-4.
        assert!(final_state["A"] < 0.05, "A did not decay: {}", final_state["A"]);
        assert!(final_state["B"] > 9.9, "B did not fill: {}", final_state["B"]);
        // The trajectory reaches the end of tspan.
        let t_end = *sol.t.last().unwrap();
        assert!((t_end - 10.0).abs() < 1e-9, "trajectory ended at {}", t_end);
    }

    #[test]
    fn test_vectorized_matches_labeled_derivative() {
        let net = decay_net();
        let state = net.set_state(None);
        let prob = Problem::new(net, state, [0.0, 1.0], unit_rates());

        let du_labeled = (prob.f)(0.0, &prob.u0);
        let du_dense = (prob.vec_f)(0.0, &prob.vec_u0);
        assert_eq!(du_dense.len(), prob.state_labels.len());
        for (i, label) in prob.state_labels.iter().enumerate() {
            assert_eq!(du_labeled[label.as_str()], du_dense[i], "mismatch for {}", label);
        }
    }
}