Skip to main content

pflow_solver/
implicit.rs

1//! Implicit ODE methods for stiff systems.
2
3use pflow_core::State;
4
5use crate::methods;
6use crate::ode::{copy_state, solve, Options, Problem, Solution};
7
8/// Backward Euler method for stiff ODEs.
9///
10/// Uses fixed-point iteration to solve the implicit equation.
11pub fn implicit_euler(prob: &Problem, opts: &Options) -> Solution {
12    let dt = opts.dt;
13    let maxiters = opts.maxiters;
14    let abstol = opts.abstol;
15
16    let t0 = prob.tspan[0];
17    let tf = prob.tspan[1];
18    let f = &prob.f;
19    let state_labels = &prob.state_labels;
20
21    let mut t_out = vec![t0];
22    let mut u_out = vec![copy_state(&prob.u0)];
23    let mut tcur = t0;
24    let mut ucur = copy_state(&prob.u0);
25    let mut nsteps = 0usize;
26
27    let max_fixed_point = 50;
28    let fixed_point_tol = abstol * 10.0;
29
30    while tcur < tf && nsteps < maxiters {
31        let mut dtcur = dt;
32        if tcur + dtcur > tf {
33            dtcur = tf - tcur;
34        }
35
36        let tnext = tcur + dtcur;
37
38        // Initial guess: explicit Euler
39        let mut unext = copy_state(&ucur);
40        let du = f(tcur, &ucur);
41        for key in state_labels {
42            if let (Some(un), Some(d)) = (unext.get_mut(key), du.get(key)) {
43                *un += dtcur * d;
44            }
45        }
46
47        // Fixed-point iteration
48        for _ in 0..max_fixed_point {
49            let mut unew = copy_state(&ucur);
50            let dunext = f(tnext, &unext);
51            for key in state_labels {
52                if let (Some(un), Some(d)) = (unew.get_mut(key), dunext.get(key)) {
53                    *un += dtcur * d;
54                }
55            }
56
57            let mut max_diff = 0.0f64;
58            for key in state_labels {
59                let diff = (unew.get(key).unwrap_or(&0.0) - unext.get(key).unwrap_or(&0.0)).abs();
60                max_diff = max_diff.max(diff);
61            }
62
63            unext = unew;
64            if max_diff < fixed_point_tol {
65                break;
66            }
67        }
68
69        tcur = tnext;
70        ucur = unext;
71        t_out.push(tcur);
72        u_out.push(copy_state(&ucur));
73        nsteps += 1;
74    }
75
76    Solution {
77        t: t_out,
78        u: u_out,
79        state_labels: state_labels.clone(),
80    }
81}
82
83/// Detects whether the problem is stiff using a heuristic.
84pub fn detect_stiffness(prob: &Problem) -> bool {
85    let du = (prob.f)(prob.tspan[0], &prob.u0);
86
87    let mut max_du = 0.0f64;
88    let mut min_du = f64::MAX;
89
90    for v in du.values() {
91        let abs_v = v.abs();
92        if abs_v > 1e-10 {
93            max_du = max_du.max(abs_v);
94            min_du = min_du.min(abs_v);
95        }
96    }
97
98    if min_du < 1e-10 || max_du < 1e-10 {
99        return false;
100    }
101
102    max_du / min_du > 1000.0
103}
104
105/// Chooses between explicit and implicit methods based on stiffness detection.
106pub fn solve_implicit(prob: &Problem, opts: &Options) -> Solution {
107    if detect_stiffness(prob) {
108        let implicit_opts = Options {
109            adaptive: false,
110            ..opts.clone()
111        };
112        implicit_euler(prob, &implicit_opts)
113    } else {
114        solve(prob, &methods::tsit5(), opts)
115    }
116}
117
118/// TR-BDF2 method: two-stage implicit method combining trapezoidal rule with BDF2.
119pub fn trbdf2(prob: &Problem, opts: &Options) -> Solution {
120    let dt = opts.dt;
121    let maxiters = opts.maxiters;
122    let abstol = opts.abstol;
123
124    let t0 = prob.tspan[0];
125    let tf = prob.tspan[1];
126    let f = &prob.f;
127    let state_labels = &prob.state_labels;
128
129    let mut t_out = vec![t0];
130    let mut u_out = vec![copy_state(&prob.u0)];
131    let mut tcur = t0;
132    let mut ucur = copy_state(&prob.u0);
133    let mut nsteps = 0usize;
134
135    let gamma = 2.0 - f64::sqrt(2.0);
136    let max_fixed_point = 50;
137    let fixed_point_tol = abstol * 10.0;
138
139    while tcur < tf && nsteps < maxiters {
140        let mut dtcur = dt;
141        if tcur + dtcur > tf {
142            dtcur = tf - tcur;
143        }
144
145        // Stage 1: Trapezoidal rule from t to t + gamma*dt
146        let tgamma = tcur + gamma * dtcur;
147        let mut ugamma = copy_state(&ucur);
148        let du0 = f(tcur, &ucur);
149
150        for key in state_labels {
151            if let (Some(ug), Some(d)) = (ugamma.get_mut(key), du0.get(key)) {
152                *ug += gamma * dtcur * d;
153            }
154        }
155
156        for _ in 0..max_fixed_point {
157            let dugamma = f(tgamma, &ugamma);
158            let mut unew = copy_state(&ucur);
159            for key in state_labels {
160                if let (Some(un), Some(d0), Some(dg)) =
161                    (unew.get_mut(key), du0.get(key), dugamma.get(key))
162                {
163                    *un += 0.5 * gamma * dtcur * (d0 + dg);
164                }
165            }
166
167            let mut max_diff = 0.0f64;
168            for key in state_labels {
169                let diff = (unew.get(key).unwrap_or(&0.0) - ugamma.get(key).unwrap_or(&0.0)).abs();
170                max_diff = max_diff.max(diff);
171            }
172
173            ugamma = unew;
174            if max_diff < fixed_point_tol {
175                break;
176            }
177        }
178
179        // Stage 2: BDF2-like step
180        let tnext = tcur + dtcur;
181        let mut unext = copy_state(&ugamma);
182
183        let dugamma = f(tgamma, &ugamma);
184        for key in state_labels {
185            if let (Some(un), Some(dg)) = (unext.get_mut(key), dugamma.get(key)) {
186                *un += (1.0 - gamma) * dtcur * dg;
187            }
188        }
189
190        let w1 = 1.0 / (gamma * (2.0 - gamma));
191        let w0 = -((1.0 - gamma) * (1.0 - gamma)) / (gamma * (2.0 - gamma));
192        let wf = (1.0 - gamma) / (2.0 - gamma);
193
194        for _ in 0..max_fixed_point {
195            let dunext = f(tnext, &unext);
196            let mut unew: State = State::new();
197            for key in state_labels {
198                let ug = ugamma.get(key).copied().unwrap_or(0.0);
199                let uc = ucur.get(key).copied().unwrap_or(0.0);
200                let dn = dunext.get(key).copied().unwrap_or(0.0);
201                unew.insert(key.clone(), w1 * ug + w0 * uc + wf * dtcur * dn);
202            }
203
204            let mut max_diff = 0.0f64;
205            for key in state_labels {
206                let diff = (unew.get(key).unwrap_or(&0.0) - unext.get(key).unwrap_or(&0.0)).abs();
207                max_diff = max_diff.max(diff);
208            }
209
210            unext = unew;
211            if max_diff < fixed_point_tol {
212                break;
213            }
214        }
215
216        tcur = tnext;
217        ucur = unext;
218        t_out.push(tcur);
219        u_out.push(copy_state(&ucur));
220        nsteps += 1;
221    }
222
223    Solution {
224        t: t_out,
225        u: u_out,
226        state_labels: state_labels.clone(),
227    }
228}