1use std::collections::HashMap;
4
5use pflow_core::net::PetriNet;
6use pflow_core::State;
7
8use crate::methods::Solver;
9
/// Boxed right-hand side of an ODE system.
///
/// Called as `f(t, &u)` with the current time and state; the returned
/// [`State`] is used by `solve` as the slope (`du/dt`) of each state variable.
pub type ODEFunc = Box<dyn Fn(f64, &State) -> State>;
12
13pub struct Problem {
15 pub net: PetriNet,
16 pub u0: State,
17 pub tspan: [f64; 2],
18 pub rates: HashMap<String, f64>,
19 pub f: ODEFunc,
20 pub state_labels: Vec<String>,
21}
22
23impl Problem {
24 pub fn new(
26 net: PetriNet,
27 initial_state: State,
28 tspan: [f64; 2],
29 rates: HashMap<String, f64>,
30 ) -> Self {
31 let f = build_ode_function(&net, &rates);
32 let state_labels: Vec<String> = initial_state.keys().cloned().collect();
33 Self {
34 net,
35 u0: initial_state,
36 tspan,
37 rates,
38 f,
39 state_labels,
40 }
41 }
42}
43
44fn build_ode_function(net: &PetriNet, rates: &HashMap<String, f64>) -> ODEFunc {
46 let place_labels: Vec<String> = net.places.keys().cloned().collect();
48 let trans_labels: Vec<String> = net.transitions.keys().cloned().collect();
49 let arcs: Vec<(String, String, f64)> = net
50 .arcs
51 .iter()
52 .map(|a| (a.source.clone(), a.target.clone(), a.weight_sum()))
53 .collect();
54 let place_set: std::collections::HashSet<String> = net.places.keys().cloned().collect();
55 let rates = rates.clone();
56
57 Box::new(move |_t: f64, u: &State| -> State {
58 let mut du: State = place_labels.iter().map(|l| (l.clone(), 0.0)).collect();
59
60 for trans_label in &trans_labels {
61 let rate = rates.get(trans_label).copied().unwrap_or(1.0);
62 let mut flux = rate;
63
64 for (source, target, _weight) in &arcs {
66 if target == trans_label && place_set.contains(source) {
67 let place_state = u.get(source).copied().unwrap_or(0.0);
68 if place_state <= 0.0 {
69 flux = 0.0;
70 break;
71 }
72 flux *= place_state;
73 }
74 }
75
76 if flux > 0.0 {
78 for (source, target, weight) in &arcs {
79 if target == trans_label && place_set.contains(source) {
80 if let Some(v) = du.get_mut(source) {
82 *v -= flux * weight;
83 }
84 } else if source == trans_label && place_set.contains(target) {
85 if let Some(v) = du.get_mut(target) {
87 *v += flux * weight;
88 }
89 }
90 }
91 }
92 }
93 du
94 })
95}
96
97pub struct Solution {
99 pub t: Vec<f64>,
100 pub u: Vec<State>,
101 pub state_labels: Vec<String>,
102}
103
104impl Solution {
105 pub fn get_variable(&self, label: &str) -> Vec<f64> {
107 self.u
108 .iter()
109 .map(|s| s.get(label).copied().unwrap_or(0.0))
110 .collect()
111 }
112
113 pub fn get_final_state(&self) -> Option<&State> {
115 self.u.last()
116 }
117
118 pub fn get_state(&self, i: usize) -> Option<&State> {
120 self.u.get(i)
121 }
122}
123
/// Tuning parameters read by `solve`.
///
/// [`Options::default()`] is equivalent to [`Options::default_opts`].
#[derive(Debug, Clone, PartialEq)]
pub struct Options {
    /// Initial step size.
    pub dt: f64,
    /// Lower bound for the adapted step size; a step no larger than this is
    /// accepted regardless of its error estimate.
    pub dtmin: f64,
    /// Upper bound for the step size after an accepted step.
    pub dtmax: f64,
    /// Absolute term of the error scale `abstol + reltol * max(|u|, |u_next|)`.
    pub abstol: f64,
    /// Relative term of the error scale `abstol + reltol * max(|u|, |u_next|)`.
    pub reltol: f64,
    /// Maximum number of accepted steps.
    pub maxiters: usize,
    /// Enables error estimation and step-size control. When `false`, every
    /// step is accepted and the step size is only shortened to land on the
    /// end of the interval.
    pub adaptive: bool,
}

impl Default for Options {
    /// Same values as [`Options::default_opts`].
    fn default() -> Self {
        Self::default_opts()
    }
}

impl Options {
    /// Baseline preset: `dt = 0.01`, steps clamped to `[1e-6, 0.1]`,
    /// `abstol = 1e-6`, `reltol = 1e-3`, up to 100 000 accepted steps.
    pub fn default_opts() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.1,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Same as [`Options::default_opts`] except `dtmax = 1.0`.
    // NOTE(review): presumably mirrors the defaults of a JavaScript
    // implementation — confirm against that project.
    pub fn js_parity() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 1.0,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Loose tolerances (`1e-2`), large steps (`dt = 0.1`, up to `1.0`) and at
    /// most 1 000 accepted steps.
    pub fn fast() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 1_000,
            adaptive: true,
        }
    }

    /// Tight tolerances (`abstol = 1e-9`, `reltol = 1e-6`), small initial step
    /// (`dt = 0.001`) and up to 1 000 000 accepted steps.
    pub fn accurate() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-8,
            dtmax: 0.1,
            abstol: 1e-9,
            reltol: 1e-6,
            maxiters: 1_000_000,
            adaptive: true,
        }
    }

    /// Smallest step bounds of all presets (`dtmin = 1e-10`, `dtmax = 0.01`)
    /// with `abstol = 1e-8`, `reltol = 1e-5` and up to 500 000 accepted steps.
    pub fn stiff() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-10,
            dtmax: 0.01,
            abstol: 1e-8,
            reltol: 1e-5,
            maxiters: 500_000,
            adaptive: true,
        }
    }

    /// Cheapest preset: loose tolerances (`1e-2`), `dtmin = 1e-3` and at most
    /// 500 accepted steps.
    pub fn game_ai() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-3,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 500,
            adaptive: true,
        }
    }

    /// Like [`Options::default_opts`] but with `dtmax = 0.5`, `reltol = 1e-4`
    /// and up to 200 000 accepted steps.
    pub fn epidemic() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.5,
            abstol: 1e-6,
            reltol: 1e-4,
            maxiters: 200_000,
            adaptive: true,
        }
    }

    /// Large steps (`dt = 0.1`, up to `10.0`) with `abstol = 1e-4`,
    /// `reltol = 1e-3` and up to 50 000 accepted steps.
    pub fn workflow() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-4,
            reltol: 1e-3,
            maxiters: 50_000,
            adaptive: true,
        }
    }

    /// Same step bounds as [`Options::workflow`] with `abstol = 1e-5` and up
    /// to 500 000 accepted steps.
    pub fn long_run() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-5,
            reltol: 1e-3,
            maxiters: 500_000,
            adaptive: true,
        }
    }
}
254
255pub fn copy_state(s: &State) -> State {
257 s.clone()
258}
259
260pub fn solve(prob: &Problem, solver: &Solver, opts: &Options) -> Solution {
262 let dt = opts.dt;
263 let dtmin = opts.dtmin;
264 let dtmax = opts.dtmax;
265 let abstol = opts.abstol;
266 let reltol = opts.reltol;
267 let maxiters = opts.maxiters;
268 let adaptive = opts.adaptive;
269
270 let t0 = prob.tspan[0];
271 let tf = prob.tspan[1];
272 let f = &prob.f;
273 let state_labels = &prob.state_labels;
274
275 let mut t_out = vec![t0];
276 let mut u_out = vec![copy_state(&prob.u0)];
277 let mut tcur = t0;
278 let mut ucur = copy_state(&prob.u0);
279 let mut dtcur = dt;
280 let mut nsteps = 0usize;
281
282 while tcur < tf && nsteps < maxiters {
283 if tcur + dtcur > tf {
285 dtcur = tf - tcur;
286 }
287
288 let num_stages = solver.c.len();
290 let mut k: Vec<State> = Vec::with_capacity(num_stages);
291 k.push(f(tcur, &ucur));
292
293 for stage in 1..num_stages {
294 let tstage = tcur + solver.c[stage] * dtcur;
295 let mut ustage = copy_state(&ucur);
296 for key in state_labels {
297 for j in 0..stage {
298 let aj = if stage < solver.a.len() && j < solver.a[stage].len() {
299 solver.a[stage][j]
300 } else {
301 0.0
302 };
303 if let (Some(us), Some(kj)) = (ustage.get_mut(key), k[j].get(key)) {
304 *us += dtcur * aj * kj;
305 }
306 }
307 }
308 k.push(f(tstage, &ustage));
309 }
310
311 let mut unext = copy_state(&ucur);
313 for key in state_labels {
314 for j in 0..solver.b.len() {
315 if let (Some(un), Some(kj)) = (unext.get_mut(key), k[j].get(key)) {
316 *un += dtcur * solver.b[j] * kj;
317 }
318 }
319 }
320
321 let mut err = 0.0;
323 if adaptive {
324 for key in state_labels {
325 let mut errest = 0.0;
326 for j in 0..solver.b_hat.len() {
327 if let Some(kj) = k[j].get(key) {
328 errest += dtcur * solver.b_hat[j] * kj;
329 }
330 }
331 let uc = ucur.get(key).copied().unwrap_or(0.0);
332 let un = unext.get(key).copied().unwrap_or(0.0);
333 let mut scale = abstol + reltol * uc.abs().max(un.abs());
334 if scale == 0.0 {
335 scale = abstol;
336 }
337 let val = errest.abs() / scale;
338 if val > err {
339 err = val;
340 }
341 }
342 }
343
344 if !adaptive || err <= 1.0 || dtcur <= dtmin {
346 tcur += dtcur;
347 ucur = unext;
348 t_out.push(tcur);
349 u_out.push(copy_state(&ucur));
350 nsteps += 1;
351
352 if adaptive && err > 0.0 {
353 let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
354 let factor = factor.min(5.0);
355 dtcur = dtmax.min(dtmin.max(dtcur * factor));
356 }
357 } else {
358 let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
359 let factor = factor.max(0.1);
360 dtcur = dtmin.max(dtcur * factor);
361 }
362 }
363
364 Solution {
365 t: t_out,
366 u: u_out,
367 state_labels: state_labels.clone(),
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::methods;

    /// A single transition moving tokens from `A` to `B` must keep the sum
    /// `A + B` at its initial value of 10 (within 0.1).
    #[test]
    fn test_simple_decay() {
        let net = PetriNet::build()
            .place("A", 10.0)
            .place("B", 0.0)
            .transition("t1")
            .arc("A", "t1", 1.0)
            .arc("t1", "B", 1.0)
            .done();

        // Passing `None` asks the net for its own state and rates.
        let initial = net.set_state(None);
        let rate_table = net.set_rates(None);
        let problem = Problem::new(net, initial, [0.0, 10.0], rate_table);

        let integrator = methods::tsit5();
        let settings = Options::default_opts();
        let solution = solve(&problem, &integrator, &settings);

        let last = solution.get_final_state().unwrap();
        let total = last["A"] + last["B"];
        assert!((total - 10.0).abs() < 0.1);
    }
}