1use pflow_core::State;
4
5use crate::methods;
6use crate::ode::{copy_state, solve, Options, Problem, Solution};
7
8pub fn implicit_euler(prob: &Problem, opts: &Options) -> Solution {
12 let dt = opts.dt;
13 let maxiters = opts.maxiters;
14 let abstol = opts.abstol;
15
16 let t0 = prob.tspan[0];
17 let tf = prob.tspan[1];
18 let f = &prob.f;
19 let state_labels = &prob.state_labels;
20
21 let mut t_out = vec![t0];
22 let mut u_out = vec![copy_state(&prob.u0)];
23 let mut tcur = t0;
24 let mut ucur = copy_state(&prob.u0);
25 let mut nsteps = 0usize;
26
27 let max_fixed_point = 50;
28 let fixed_point_tol = abstol * 10.0;
29
30 while tcur < tf && nsteps < maxiters {
31 let mut dtcur = dt;
32 if tcur + dtcur > tf {
33 dtcur = tf - tcur;
34 }
35
36 let tnext = tcur + dtcur;
37
38 let mut unext = copy_state(&ucur);
40 let du = f(tcur, &ucur);
41 for key in state_labels {
42 if let (Some(un), Some(d)) = (unext.get_mut(key), du.get(key)) {
43 *un += dtcur * d;
44 }
45 }
46
47 for _ in 0..max_fixed_point {
49 let mut unew = copy_state(&ucur);
50 let dunext = f(tnext, &unext);
51 for key in state_labels {
52 if let (Some(un), Some(d)) = (unew.get_mut(key), dunext.get(key)) {
53 *un += dtcur * d;
54 }
55 }
56
57 let mut max_diff = 0.0f64;
58 for key in state_labels {
59 let diff = (unew.get(key).unwrap_or(&0.0) - unext.get(key).unwrap_or(&0.0)).abs();
60 max_diff = max_diff.max(diff);
61 }
62
63 unext = unew;
64 if max_diff < fixed_point_tol {
65 break;
66 }
67 }
68
69 tcur = tnext;
70 ucur = unext;
71 t_out.push(tcur);
72 u_out.push(copy_state(&ucur));
73 nsteps += 1;
74 }
75
76 Solution {
77 t: t_out,
78 u: u_out,
79 state_labels: state_labels.clone(),
80 }
81}
82
83pub fn detect_stiffness(prob: &Problem) -> bool {
85 let du = (prob.f)(prob.tspan[0], &prob.u0);
86
87 let mut max_du = 0.0f64;
88 let mut min_du = f64::MAX;
89
90 for v in du.values() {
91 let abs_v = v.abs();
92 if abs_v > 1e-10 {
93 max_du = max_du.max(abs_v);
94 min_du = min_du.min(abs_v);
95 }
96 }
97
98 if min_du < 1e-10 || max_du < 1e-10 {
99 return false;
100 }
101
102 max_du / min_du > 1000.0
103}
104
105pub fn solve_implicit(prob: &Problem, opts: &Options) -> Solution {
107 if detect_stiffness(prob) {
108 let implicit_opts = Options {
109 adaptive: false,
110 ..opts.clone()
111 };
112 implicit_euler(prob, &implicit_opts)
113 } else {
114 solve(prob, &methods::tsit5(), opts)
115 }
116}
117
118pub fn trbdf2(prob: &Problem, opts: &Options) -> Solution {
120 let dt = opts.dt;
121 let maxiters = opts.maxiters;
122 let abstol = opts.abstol;
123
124 let t0 = prob.tspan[0];
125 let tf = prob.tspan[1];
126 let f = &prob.f;
127 let state_labels = &prob.state_labels;
128
129 let mut t_out = vec![t0];
130 let mut u_out = vec![copy_state(&prob.u0)];
131 let mut tcur = t0;
132 let mut ucur = copy_state(&prob.u0);
133 let mut nsteps = 0usize;
134
135 let gamma = 2.0 - f64::sqrt(2.0);
136 let max_fixed_point = 50;
137 let fixed_point_tol = abstol * 10.0;
138
139 while tcur < tf && nsteps < maxiters {
140 let mut dtcur = dt;
141 if tcur + dtcur > tf {
142 dtcur = tf - tcur;
143 }
144
145 let tgamma = tcur + gamma * dtcur;
147 let mut ugamma = copy_state(&ucur);
148 let du0 = f(tcur, &ucur);
149
150 for key in state_labels {
151 if let (Some(ug), Some(d)) = (ugamma.get_mut(key), du0.get(key)) {
152 *ug += gamma * dtcur * d;
153 }
154 }
155
156 for _ in 0..max_fixed_point {
157 let dugamma = f(tgamma, &ugamma);
158 let mut unew = copy_state(&ucur);
159 for key in state_labels {
160 if let (Some(un), Some(d0), Some(dg)) =
161 (unew.get_mut(key), du0.get(key), dugamma.get(key))
162 {
163 *un += 0.5 * gamma * dtcur * (d0 + dg);
164 }
165 }
166
167 let mut max_diff = 0.0f64;
168 for key in state_labels {
169 let diff = (unew.get(key).unwrap_or(&0.0) - ugamma.get(key).unwrap_or(&0.0)).abs();
170 max_diff = max_diff.max(diff);
171 }
172
173 ugamma = unew;
174 if max_diff < fixed_point_tol {
175 break;
176 }
177 }
178
179 let tnext = tcur + dtcur;
181 let mut unext = copy_state(&ugamma);
182
183 let dugamma = f(tgamma, &ugamma);
184 for key in state_labels {
185 if let (Some(un), Some(dg)) = (unext.get_mut(key), dugamma.get(key)) {
186 *un += (1.0 - gamma) * dtcur * dg;
187 }
188 }
189
190 let w1 = 1.0 / (gamma * (2.0 - gamma));
191 let w0 = -((1.0 - gamma) * (1.0 - gamma)) / (gamma * (2.0 - gamma));
192 let wf = (1.0 - gamma) / (2.0 - gamma);
193
194 for _ in 0..max_fixed_point {
195 let dunext = f(tnext, &unext);
196 let mut unew: State = State::new();
197 for key in state_labels {
198 let ug = ugamma.get(key).copied().unwrap_or(0.0);
199 let uc = ucur.get(key).copied().unwrap_or(0.0);
200 let dn = dunext.get(key).copied().unwrap_or(0.0);
201 unew.insert(key.clone(), w1 * ug + w0 * uc + wf * dtcur * dn);
202 }
203
204 let mut max_diff = 0.0f64;
205 for key in state_labels {
206 let diff = (unew.get(key).unwrap_or(&0.0) - unext.get(key).unwrap_or(&0.0)).abs();
207 max_diff = max_diff.max(diff);
208 }
209
210 unext = unew;
211 if max_diff < fixed_point_tol {
212 break;
213 }
214 }
215
216 tcur = tnext;
217 ucur = unext;
218 t_out.push(tcur);
219 u_out.push(copy_state(&ucur));
220 nsteps += 1;
221 }
222
223 Solution {
224 t: t_out,
225 u: u_out,
226 state_labels: state_labels.clone(),
227 }
228}