differential_equations/methods/dirk/fixed/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 linalg::Matrix,
7 methods::{DiagonallyImplicitRungeKutta, Fixed, Ordinary},
8 ode::{ODE, OrdinaryNumericalMethod},
9 stats::Evals,
10 status::Status,
11 traits::{CallBackData, Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
16 OrdinaryNumericalMethod<T, Y, D>
17 for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
18{
19 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
20 where
21 F: ODE<T, Y, D>,
22 {
23 let mut evals = Evals::new();
24
25 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
27 Ok(h0) => self.h = h0,
28 Err(status) => return Err(status),
29 }
30
31 self.stiffness_counter = 0;
33 self.newton_iterations = 0;
34 self.jacobian_evaluations = 0;
35 self.lu_decompositions = 0;
36
37 self.t = t0;
39 self.y = *y0;
40 ode.diff(self.t, &self.y, &mut self.dydt);
41 evals.function += 1;
42
43 self.t_prev = self.t;
45 self.y_prev = self.y;
46 self.dydt_prev = self.dydt;
47
48 let dim = y0.len();
50 self.jacobian = Matrix::zeros(dim);
51 self.z = *y0;
52 self.jacobian_age = 0;
53
54 self.status = Status::Initialized;
56
57 Ok(evals)
58 }
59
60 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
61 where
62 F: ODE<T, Y, D>,
63 {
64 let mut evals = Evals::new();
65
66 if self.steps >= self.max_steps {
68 self.status = Status::Error(Error::MaxSteps {
69 t: self.t,
70 y: self.y,
71 });
72 return Err(Error::MaxSteps {
73 t: self.t,
74 y: self.y,
75 });
76 }
77 self.steps += 1;
78
79 let dim = self.y.len();
80
81 for stage in 0..self.stages {
83 let mut rhs = self.y;
85 for j in 0..stage {
86 rhs += self.k[j] * (self.a[stage][j] * self.h);
87 }
88
89 self.z = self.y;
91
92 let mut newton_converged = false;
94 let mut newton_iter = 0;
95 let mut increment_norm = T::infinity();
96
97 while !newton_converged && newton_iter < self.max_newton_iter {
98 newton_iter += 1;
99 self.newton_iterations += 1;
100 evals.newton += 1;
101
102 let t_stage = self.t + self.c[stage] * self.h;
104 let mut f_stage = Y::zeros();
105 ode.diff(t_stage, &self.z, &mut f_stage);
106 evals.function += 1;
107
108 let residual = self.z - rhs - f_stage * (self.a[stage][stage] * self.h);
110
111 let mut residual_norm = T::zero();
113 self.rhs_newton = -residual;
114 for i in 0..dim {
115 residual_norm = residual_norm.max(residual.get(i).abs());
116 }
117
118 if residual_norm < self.newton_tol {
120 newton_converged = true;
121 break;
122 }
123
124 if newton_iter > 1 && increment_norm < self.newton_tol {
126 newton_converged = true;
127 break;
128 }
129
130 if newton_iter == 1 || self.jacobian_age > 3 {
132 ode.jacobian(t_stage, &self.z, &mut self.jacobian);
133 evals.jacobian += 1;
134 self.jacobian_age = 0;
135
136 self.jacobian
138 .component_mul_mut(-self.h * self.a[stage][stage]);
139 self.jacobian += Matrix::identity(dim);
140 }
141 self.jacobian_age += 1;
142
143 self.delta_z = self.jacobian.lin_solve(self.rhs_newton);
145 self.lu_decompositions += 1;
146
147 increment_norm = T::zero();
149 self.z = self.z + self.delta_z;
150 for row_idx in 0..dim {
151 increment_norm = increment_norm.max(self.delta_z.get(row_idx).abs());
153 }
154 }
155
156 if !newton_converged {
158 self.status = Status::Error(Error::Stiffness {
159 t: self.t,
160 y: self.y,
161 });
162 return Err(Error::Stiffness {
163 t: self.t,
164 y: self.y,
165 });
166 }
167
168 let t_stage = self.t + self.c[stage] * self.h;
170 ode.diff(t_stage, &self.z, &mut self.k[stage]);
171 evals.function += 1;
172 }
173
174 let mut y_new = self.y;
176 for i in 0..self.stages {
177 y_new += self.k[i] * (self.b[i] * self.h);
178 }
179
180 self.status = Status::Solving;
182
183 self.t_prev = self.t;
185 self.y_prev = self.y;
186 self.dydt_prev = self.dydt;
187 self.h_prev = self.h;
188
189 self.t += self.h;
190 self.y = y_new;
191
192 ode.diff(self.t, &self.y, &mut self.dydt);
194 evals.function += 1;
195
196 Ok(evals)
197 }
198
199 fn t(&self) -> T {
200 self.t
201 }
202 fn y(&self) -> &Y {
203 &self.y
204 }
205 fn t_prev(&self) -> T {
206 self.t_prev
207 }
208 fn y_prev(&self) -> &Y {
209 &self.y_prev
210 }
211 fn h(&self) -> T {
212 self.h
213 }
214 fn set_h(&mut self, h: T) {
215 self.h = h;
216 }
217 fn status(&self) -> &Status<T, Y, D> {
218 &self.status
219 }
220 fn set_status(&mut self, status: Status<T, Y, D>) {
221 self.status = status;
222 }
223}
224
225impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
226 Interpolation<T, Y> for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
227{
228 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
229 if t_interp < self.t_prev || t_interp > self.t {
231 return Err(Error::OutOfBounds {
232 t_interp,
233 t_prev: self.t_prev,
234 t_curr: self.t,
235 });
236 }
237
238 let y_interp = cubic_hermite_interpolate(
240 self.t_prev,
241 self.t,
242 &self.y_prev,
243 &self.y,
244 &self.dydt_prev,
245 &self.dydt,
246 t_interp,
247 );
248
249 Ok(y_interp)
250 }
251}