1use std::collections::HashMap;
4
5use pflow_core::net::PetriNet;
6use pflow_core::State;
7
8use crate::methods::Solver;
9
/// Boxed derivative function over labelled states: called with the time and
/// the current [`State`], it returns a [`State`] holding the derivatives.
pub type ODEFunc = Box<dyn Fn(f64, &State) -> State>;
12
// Dense counterpart of `ODEFunc`: called with the time and the state as a
// slice, it returns the derivatives as a `Vec<f64>`.
type VecF = Box<dyn Fn(f64, &[f64]) -> Vec<f64>>;
15
16pub struct Problem {
18 pub net: PetriNet,
19 pub u0: State,
20 pub tspan: [f64; 2],
21 pub rates: HashMap<String, f64>,
22 pub f: ODEFunc,
23 pub state_labels: Vec<String>,
24 #[allow(dead_code)]
26 state_index: HashMap<String, usize>,
27 vec_u0: Vec<f64>,
28 vec_f: VecF,
29}
30
31impl Problem {
32 pub fn new(
34 net: PetriNet,
35 initial_state: State,
36 tspan: [f64; 2],
37 rates: HashMap<String, f64>,
38 ) -> Self {
39 let f = build_ode_function(&net, &rates);
40 let state_labels: Vec<String> = initial_state.keys().cloned().collect();
41 let state_index: HashMap<String, usize> = state_labels
42 .iter()
43 .enumerate()
44 .map(|(i, label)| (label.clone(), i))
45 .collect();
46 let vec_u0: Vec<f64> = state_labels
47 .iter()
48 .map(|label| initial_state.get(label).copied().unwrap_or(0.0))
49 .collect();
50 let n_places = state_labels.len();
51 let vec_f = build_vec_ode_function(&net, &rates, &state_index, n_places);
52 Self {
53 net,
54 u0: initial_state,
55 tspan,
56 rates,
57 f,
58 state_labels,
59 state_index,
60 vec_u0,
61 vec_f,
62 }
63 }
64}
65
66fn build_ode_function(net: &PetriNet, rates: &HashMap<String, f64>) -> ODEFunc {
68 let place_labels: Vec<String> = net.places.keys().cloned().collect();
70 let trans_labels: Vec<String> = net.transitions.keys().cloned().collect();
71 let arcs: Vec<(String, String, f64)> = net
72 .arcs
73 .iter()
74 .map(|a| (a.source.clone(), a.target.clone(), a.weight_sum()))
75 .collect();
76 let place_set: std::collections::HashSet<String> = net.places.keys().cloned().collect();
77 let rates = rates.clone();
78
79 Box::new(move |_t: f64, u: &State| -> State {
80 let mut du: State = place_labels.iter().map(|l| (l.clone(), 0.0)).collect();
81
82 for trans_label in &trans_labels {
83 let rate = rates.get(trans_label).copied().unwrap_or(1.0);
84 let mut flux = rate;
85
86 for (source, target, _weight) in &arcs {
88 if target == trans_label && place_set.contains(source) {
89 let place_state = u.get(source).copied().unwrap_or(0.0);
90 if place_state <= 0.0 {
91 flux = 0.0;
92 break;
93 }
94 flux *= place_state;
95 }
96 }
97
98 if flux > 0.0 {
100 for (source, target, weight) in &arcs {
101 if target == trans_label && place_set.contains(source) {
102 if let Some(v) = du.get_mut(source) {
104 *v -= flux * weight;
105 }
106 } else if source == trans_label && place_set.contains(target) {
107 if let Some(v) = du.get_mut(target) {
109 *v += flux * weight;
110 }
111 }
112 }
113 }
114 }
115 du
116 })
117}
118
119fn build_vec_ode_function(
124 net: &PetriNet,
125 rates: &HashMap<String, f64>,
126 state_index: &HashMap<String, usize>,
127 n_places: usize,
128) -> VecF {
129 let mut input_map: HashMap<&str, Vec<(usize, f64)>> = HashMap::new();
131 let mut output_map: HashMap<&str, Vec<(usize, f64)>> = HashMap::new();
132
133 for arc in &net.arcs {
134 let w = arc.weight_sum();
135 if net.transitions.contains_key(&arc.target) {
136 if let Some(&idx) = state_index.get(&arc.source) {
137 input_map
138 .entry(arc.target.as_str())
139 .or_default()
140 .push((idx, w));
141 }
142 }
143 if net.transitions.contains_key(&arc.source) {
144 if let Some(&idx) = state_index.get(&arc.target) {
145 output_map
146 .entry(arc.source.as_str())
147 .or_default()
148 .push((idx, w));
149 }
150 }
151 }
152
153 let transitions: Vec<(f64, Vec<(usize, f64)>, Vec<(usize, f64)>)> = net
155 .transitions
156 .keys()
157 .map(|label| {
158 let rate = rates.get(label).copied().unwrap_or(1.0);
159 let inputs = input_map.remove(label.as_str()).unwrap_or_default();
160 let outputs = output_map.remove(label.as_str()).unwrap_or_default();
161 (rate, inputs, outputs)
162 })
163 .collect();
164
165 Box::new(move |_t: f64, u: &[f64]| -> Vec<f64> {
166 let mut du = vec![0.0; n_places];
167
168 for (rate, inputs, outputs) in &transitions {
169 let mut flux = *rate;
170
171 for &(idx, _w) in inputs {
173 let v = u[idx];
174 if v <= 0.0 {
175 flux = 0.0;
176 break;
177 }
178 flux *= v;
179 }
180
181 if flux > 0.0 {
182 for &(idx, w) in inputs {
183 du[idx] -= flux * w;
184 }
185 for &(idx, w) in outputs {
186 du[idx] += flux * w;
187 }
188 }
189 }
190
191 du
192 })
193}
194
195pub struct Solution {
197 pub t: Vec<f64>,
198 pub u: Vec<State>,
199 pub state_labels: Vec<String>,
200}
201
202impl Solution {
203 pub fn get_variable(&self, label: &str) -> Vec<f64> {
205 self.u
206 .iter()
207 .map(|s| s.get(label).copied().unwrap_or(0.0))
208 .collect()
209 }
210
211 pub fn get_final_state(&self) -> Option<&State> {
213 self.u.last()
214 }
215
216 pub fn get_state(&self, i: usize) -> Option<&State> {
218 self.u.get(i)
219 }
220}
221
/// Step-size and tolerance settings for `solve`.
#[derive(Debug, Clone)]
pub struct Options {
    /// Initial step size.
    pub dt: f64,
    /// Smallest step the adaptive controller will shrink to; a step at or
    /// below this size is always accepted.
    pub dtmin: f64,
    /// Largest step the adaptive controller will grow to.
    pub dtmax: f64,
    /// Absolute error tolerance.
    pub abstol: f64,
    /// Relative error tolerance.
    pub reltol: f64,
    /// Maximum number of accepted steps.
    pub maxiters: usize,
    /// Whether the step size is adapted from the error estimate.
    pub adaptive: bool,
}

impl Default for Options {
    /// Same values as [`Options::default_opts`], so `Options` works with
    /// `..Default::default()` and `#[derive(Default)]` on containing types.
    fn default() -> Self {
        Self::default_opts()
    }
}

impl Options {
    /// General-purpose settings: `dt = 0.01`, steps within `[1e-6, 0.1]`,
    /// `abstol = 1e-6`, `reltol = 1e-3`, up to 100 000 steps.
    pub fn default_opts() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.1,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Same as [`Options::default_opts`] except `dtmax = 1.0`.
    // NOTE(review): presumably mirrors the JavaScript implementation's
    // settings, going by the name; confirm against that code.
    pub fn js_parity() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 1.0,
            abstol: 1e-6,
            reltol: 1e-3,
            maxiters: 100_000,
            adaptive: true,
        }
    }

    /// Coarse settings: `dt = 0.1`, tolerances of `1e-2`, at most 1 000 steps.
    pub fn fast() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 1_000,
            adaptive: true,
        }
    }

    /// Tight settings: `dt = 0.001`, `abstol = 1e-9`, `reltol = 1e-6`, up to
    /// 1 000 000 steps.
    pub fn accurate() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-8,
            dtmax: 0.1,
            abstol: 1e-9,
            reltol: 1e-6,
            maxiters: 1_000_000,
            adaptive: true,
        }
    }

    /// Small-step settings: steps within `[1e-10, 0.01]`, `abstol = 1e-8`,
    /// `reltol = 1e-5`, up to 500 000 steps.
    pub fn stiff() -> Self {
        Self {
            dt: 0.001,
            dtmin: 1e-10,
            dtmax: 0.01,
            abstol: 1e-8,
            reltol: 1e-5,
            maxiters: 500_000,
            adaptive: true,
        }
    }

    /// Very coarse settings with a hard cap of 500 steps and tolerances of
    /// `1e-2`.
    pub fn game_ai() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-3,
            dtmax: 1.0,
            abstol: 1e-2,
            reltol: 1e-2,
            maxiters: 500,
            adaptive: true,
        }
    }

    /// `dt = 0.01`, steps up to `0.5`, `abstol = 1e-6`, `reltol = 1e-4`, up to
    /// 200 000 steps.
    pub fn epidemic() -> Self {
        Self {
            dt: 0.01,
            dtmin: 1e-6,
            dtmax: 0.5,
            abstol: 1e-6,
            reltol: 1e-4,
            maxiters: 200_000,
            adaptive: true,
        }
    }

    /// `dt = 0.1`, steps up to `10.0`, `abstol = 1e-4`, `reltol = 1e-3`, up to
    /// 50 000 steps.
    pub fn workflow() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-4,
            reltol: 1e-3,
            maxiters: 50_000,
            adaptive: true,
        }
    }

    /// Like [`Options::workflow`] but with `abstol = 1e-5` and up to 500 000
    /// steps.
    pub fn long_run() -> Self {
        Self {
            dt: 0.1,
            dtmin: 1e-4,
            dtmax: 10.0,
            abstol: 1e-5,
            reltol: 1e-3,
            maxiters: 500_000,
            adaptive: true,
        }
    }
}
352
353pub fn copy_state(s: &State) -> State {
355 s.clone()
356}
357
358fn vec_to_state(v: &[f64], labels: &[String]) -> State {
360 labels
361 .iter()
362 .enumerate()
363 .map(|(i, label)| (label.clone(), v[i]))
364 .collect()
365}
366
367pub fn solve(prob: &Problem, solver: &Solver, opts: &Options) -> Solution {
371 let dt = opts.dt;
372 let dtmin = opts.dtmin;
373 let dtmax = opts.dtmax;
374 let abstol = opts.abstol;
375 let reltol = opts.reltol;
376 let maxiters = opts.maxiters;
377 let adaptive = opts.adaptive;
378
379 let t0 = prob.tspan[0];
380 let tf = prob.tspan[1];
381 let f = &prob.vec_f;
382 let n = prob.vec_u0.len();
383
384 let mut t_out = vec![t0];
385 let mut u_out: Vec<Vec<f64>> = vec![prob.vec_u0.clone()];
386 let mut tcur = t0;
387 let mut ucur = prob.vec_u0.clone();
388 let mut dtcur = dt;
389 let mut nsteps = 0usize;
390
391 while tcur < tf && nsteps < maxiters {
392 if tcur + dtcur > tf {
394 dtcur = tf - tcur;
395 }
396
397 let num_stages = solver.c.len();
399 let mut k: Vec<Vec<f64>> = Vec::with_capacity(num_stages);
400 k.push(f(tcur, &ucur));
401
402 for stage in 1..num_stages {
403 let tstage = tcur + solver.c[stage] * dtcur;
404 let mut ustage = ucur.clone();
405 for j in 0..stage {
406 let aj = if stage < solver.a.len() && j < solver.a[stage].len() {
407 solver.a[stage][j]
408 } else {
409 0.0
410 };
411 if aj != 0.0 {
412 let scale = dtcur * aj;
413 for i in 0..n {
414 ustage[i] += scale * k[j][i];
415 }
416 }
417 }
418 k.push(f(tstage, &ustage));
419 }
420
421 let mut unext = ucur.clone();
423 for j in 0..solver.b.len() {
424 if solver.b[j] != 0.0 {
425 let scale = dtcur * solver.b[j];
426 for i in 0..n {
427 unext[i] += scale * k[j][i];
428 }
429 }
430 }
431
432 let mut err = 0.0;
434 if adaptive {
435 for i in 0..n {
436 let mut errest = 0.0;
437 for j in 0..solver.b_hat.len() {
438 errest += dtcur * solver.b_hat[j] * k[j][i];
439 }
440 let uc = ucur[i];
441 let un = unext[i];
442 let mut scale = abstol + reltol * uc.abs().max(un.abs());
443 if scale == 0.0 {
444 scale = abstol;
445 }
446 let val = errest.abs() / scale;
447 if val > err {
448 err = val;
449 }
450 }
451 }
452
453 if !adaptive || err <= 1.0 || dtcur <= dtmin {
455 tcur += dtcur;
456 ucur = unext;
457 t_out.push(tcur);
458 u_out.push(ucur.clone());
459 nsteps += 1;
460
461 if adaptive && err > 0.0 {
462 let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
463 let factor = factor.min(5.0);
464 dtcur = dtmax.min(dtmin.max(dtcur * factor));
465 }
466 } else {
467 let factor = 0.9 * (1.0 / err).powf(1.0 / (solver.order as f64 + 1.0));
468 let factor = factor.max(0.1);
469 dtcur = dtmin.max(dtcur * factor);
470 }
471 }
472
473 let state_u: Vec<State> = u_out
475 .iter()
476 .map(|v| vec_to_state(v, &prob.state_labels))
477 .collect();
478
479 Solution {
480 t: t_out,
481 u: state_u,
482 state_labels: prob.state_labels.clone(),
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::methods;

    /// The net A -> t1 -> B has one input arc and one output arc of equal
    /// weight, so the sum A + B should stay at its initial value of 10.
    #[test]
    fn test_simple_decay() {
        let net = PetriNet::build()
            .place("A", 10.0)
            .place("B", 0.0)
            .transition("t1")
            .arc("A", "t1", 1.0)
            .arc("t1", "B", 1.0)
            .done();

        let initial = net.set_state(None);
        let rates = net.set_rates(None);
        let problem = Problem::new(net, initial, [0.0, 10.0], rates);

        let solution = solve(&problem, &methods::tsit5(), &Options::default_opts());
        let last = solution.get_final_state().unwrap();

        let drift = (last["A"] + last["B"] - 10.0).abs();
        assert!(drift < 0.1);
    }
}