differential_equations/methods/dirk/fixed/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 linalg::Matrix,
7 methods::{DiagonallyImplicitRungeKutta, Fixed, Ordinary},
8 ode::{ODE, OrdinaryNumericalMethod},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16 OrdinaryNumericalMethod<T, Y> for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
17{
18 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19 where
20 F: ODE<T, Y>,
21 {
22 let mut evals = Evals::new();
23
24 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
26 Ok(h0) => self.h = h0,
27 Err(status) => return Err(status),
28 }
29
30 self.stiffness_counter = 0;
32 self.newton_iterations = 0;
33 self.jacobian_evaluations = 0;
34 self.lu_decompositions = 0;
35
36 self.t = t0;
38 self.y = *y0;
39 ode.diff(self.t, &self.y, &mut self.dydt);
40 evals.function += 1;
41
42 self.t_prev = self.t;
44 self.y_prev = self.y;
45 self.dydt_prev = self.dydt;
46
47 let dim = y0.len();
49 self.jacobian = Matrix::zeros(dim, dim);
50 self.z = *y0;
51 self.jacobian_age = 0;
52
53 self.status = Status::Initialized;
55
56 Ok(evals)
57 }
58
59 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
60 where
61 F: ODE<T, Y>,
62 {
63 let mut evals = Evals::new();
64
65 if self.steps >= self.max_steps {
67 self.status = Status::Error(Error::MaxSteps {
68 t: self.t,
69 y: self.y,
70 });
71 return Err(Error::MaxSteps {
72 t: self.t,
73 y: self.y,
74 });
75 }
76 self.steps += 1;
77
78 let dim = self.y.len();
79
80 for stage in 0..self.stages {
82 let mut rhs = self.y;
84 for j in 0..stage {
85 rhs += self.k[j] * (self.a[stage][j] * self.h);
86 }
87
88 self.z = self.y;
90
91 let mut newton_converged = false;
93 let mut newton_iter = 0;
94 let mut increment_norm = T::infinity();
95
96 while !newton_converged && newton_iter < self.max_newton_iter {
97 newton_iter += 1;
98 self.newton_iterations += 1;
99 evals.newton += 1;
100
101 let t_stage = self.t + self.c[stage] * self.h;
103 let mut f_stage = Y::zeros();
104 ode.diff(t_stage, &self.z, &mut f_stage);
105 evals.function += 1;
106
107 let residual = self.z - rhs - f_stage * (self.a[stage][stage] * self.h);
109
110 let mut residual_norm = T::zero();
112 self.rhs_newton = -residual;
113 for i in 0..dim {
114 residual_norm = residual_norm.max(residual.get(i).abs());
115 }
116
117 if residual_norm < self.newton_tol {
119 newton_converged = true;
120 break;
121 }
122
123 if newton_iter > 1 && increment_norm < self.newton_tol {
125 newton_converged = true;
126 break;
127 }
128
129 if newton_iter == 1 || self.jacobian_age > 3 {
131 ode.jacobian(t_stage, &self.z, &mut self.jacobian);
132 evals.jacobian += 1;
133 self.jacobian_age = 0;
134
135 self.jacobian
137 .component_mul_mut(-self.h * self.a[stage][stage]);
138 self.jacobian += Matrix::identity(dim);
139 }
140 self.jacobian_age += 1;
141
142 self.delta_z = self.jacobian.lin_solve(self.rhs_newton).unwrap();
144 self.lu_decompositions += 1;
145
146 increment_norm = T::zero();
148 self.z += self.delta_z;
149 for row_idx in 0..dim {
150 increment_norm = increment_norm.max(self.delta_z.get(row_idx).abs());
152 }
153 }
154
155 if !newton_converged {
157 self.status = Status::Error(Error::Stiffness {
158 t: self.t,
159 y: self.y,
160 });
161 return Err(Error::Stiffness {
162 t: self.t,
163 y: self.y,
164 });
165 }
166
167 let t_stage = self.t + self.c[stage] * self.h;
169 ode.diff(t_stage, &self.z, &mut self.k[stage]);
170 evals.function += 1;
171 }
172
173 let mut y_new = self.y;
175 for i in 0..self.stages {
176 y_new += self.k[i] * (self.b[i] * self.h);
177 }
178
179 self.status = Status::Solving;
181
182 self.t_prev = self.t;
184 self.y_prev = self.y;
185 self.dydt_prev = self.dydt;
186 self.h_prev = self.h;
187
188 self.t += self.h;
189 self.y = y_new;
190
191 ode.diff(self.t, &self.y, &mut self.dydt);
193 evals.function += 1;
194
195 Ok(evals)
196 }
197
198 fn t(&self) -> T {
199 self.t
200 }
201 fn y(&self) -> &Y {
202 &self.y
203 }
204 fn t_prev(&self) -> T {
205 self.t_prev
206 }
207 fn y_prev(&self) -> &Y {
208 &self.y_prev
209 }
210 fn h(&self) -> T {
211 self.h
212 }
213 fn set_h(&mut self, h: T) {
214 self.h = h;
215 }
216 fn status(&self) -> &Status<T, Y> {
217 &self.status
218 }
219 fn set_status(&mut self, status: Status<T, Y>) {
220 self.status = status;
221 }
222}
223
224impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
225 for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
226{
227 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
228 if t_interp < self.t_prev || t_interp > self.t {
230 return Err(Error::OutOfBounds {
231 t_interp,
232 t_prev: self.t_prev,
233 t_curr: self.t,
234 });
235 }
236
237 let y_interp = cubic_hermite_interpolate(
239 self.t_prev,
240 self.t,
241 &self.y_prev,
242 &self.y,
243 &self.dydt_prev,
244 &self.dydt,
245 t_interp,
246 );
247
248 Ok(y_interp)
249 }
250}