differential_equations/methods/irk/fixed/
ordinary.rs1use super::{ImplicitRungeKutta, Ordinary, Fixed};
4use crate::{
5 Error, Status,
6 alias::Evals,
7 interpolate::{Interpolation, cubic_hermite_interpolate},
8 ode::{OrdinaryNumericalMethod, ODE},
9 traits::{CallBackData, Real, State},
10 utils::validate_step_size_parameters,
11};
12use nalgebra::{DMatrix, DVector};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ImplicitRungeKutta<Ordinary, Fixed, T, V, D, O, S, I> {
15 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16 where
17 F: ODE<T, V, D>,
18 {
19 let mut evals = Evals::new();
20
21 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
23 Ok(h0) => self.h = h0,
25 Err(status) => return Err(status),
26 }
27
28 self.stiffness_counter = 0;
30 self.newton_iterations = 0;
31 self.jacobian_evaluations = 0;
32 self.lu_decompositions = 0;
33
34 self.t = t0;
36 self.y = *y0;
37 ode.diff(self.t, &self.y, &mut self.dydt);
38 evals.fcn += 1;
39
40 self.t_prev = self.t;
42 self.y_prev = self.y;
43 self.dydt_prev = self.dydt;
44
45 let dim = y0.len();
47 let newton_system_size = self.stages * dim;
48 self.jacobian_matrix = DMatrix::zeros(dim, dim);
49 self.newton_matrix = DMatrix::zeros(newton_system_size, newton_system_size);
50 self.rhs_newton = DVector::zeros(newton_system_size);
51 self.delta_k_vec = DVector::zeros(newton_system_size);
52 self.jacobian_age = 0;
53
54 self.status = Status::Initialized;
56
57 Ok(evals)
58 }
59
60 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
61 where
62 F: ODE<T, V, D>,
63 {
64 let mut evals = Evals::new();
65
66 if self.steps >= self.max_steps {
68 self.status = Status::Error(Error::MaxSteps {
69 t: self.t, y: self.y
70 });
71 return Err(Error::MaxSteps {
72 t: self.t, y: self.y
73 });
74 }
75 self.steps += 1;
76
77 for i in 0..self.stages {
79 self.y_stages[i] = self.y;
80 }
81
82 let mut newton_converged = false;
84 let mut newton_iter = 0;
85 let dim = self.y.len();
86
87 ode.jacobian(self.t, &self.y, &mut self.jacobian_matrix);
89 evals.jac += 1;
90
91 while !newton_converged && newton_iter < self.max_newton_iter {
92 newton_iter += 1;
93 self.newton_iterations += 1;
94
95 for i in 0..self.stages {
97 ode.diff(self.t + self.c[i] * self.h, &self.y_stages[i], &mut self.k[i]);
98 }
99 evals.fcn += self.stages;
100
101 for i in 0..self.stages {
104 self.y_stages[i] = self.y;
106 for j in 0..self.stages {
107 self.y_stages[i] += self.k[j] * (self.a[i][j] * self.h);
108 }
109
110 let mut f_at_stage = V::zeros();
112 ode.diff(self.t + self.c[i] * self.h, &self.y_stages[i], &mut f_at_stage);
113 evals.fcn += 1;
114
115 for row_idx in 0..dim {
118 self.rhs_newton[i * dim + row_idx] = f_at_stage.get(row_idx) - self.k[i].get(row_idx);
119 }
120 }
121
122 for i in 0..self.stages {
125 for j in 0..self.stages {
126 let scale_factor = -self.h * self.a[i][j];
127 for r in 0..dim {
128 for c_col in 0..dim {
129 self.newton_matrix[(i * dim + r, j * dim + c_col)] =
130 self.jacobian_matrix[(r, c_col)] * scale_factor;
131 }
132 }
133
134 if i == j {
136 for d_idx in 0..dim {
137 self.newton_matrix[(i * dim + d_idx, j * dim + d_idx)] += T::one();
138 }
139 }
140 }
141 }
142
143 let lu_decomp = nalgebra::LU::new(self.newton_matrix.clone());
145 if let Some(solution) = lu_decomp.solve(&self.rhs_newton) {
146 self.delta_k_vec.copy_from(&solution);
147 } else {
148 newton_converged = false;
150 break;
151 }
152
153 let mut norm_delta_k_sq = T::zero();
155 for i in 0..self.stages {
156 for row_idx in 0..dim {
157 let delta_val = self.delta_k_vec[i * dim + row_idx];
158 let current_val = self.k[i].get(row_idx);
159 self.k[i].set(row_idx, current_val + delta_val);
160 norm_delta_k_sq += delta_val * delta_val;
161 }
162 }
163
164 if norm_delta_k_sq < self.newton_tol * self.newton_tol {
166 newton_converged = true;
167 }
168 }
169
170 if !newton_converged {
172 self.status = Status::Error(Error::Stiffness {
173 t: self.t,
174 y: self.y,
175 });
176 return Err(Error::Stiffness {
177 t: self.t,
178 y: self.y,
179 });
180 }
181
182 for i in 0..self.stages {
184 self.y_stages[i] = self.y;
186 for j in 0..self.stages {
187 self.y_stages[i] += self.k[j] * (self.a[i][j] * self.h);
188 }
189 ode.diff(self.t + self.c[i] * self.h, &self.y_stages[i], &mut self.k[i]);
190 }
191 evals.fcn += self.stages;
192
193 let mut y_new = self.y;
195 for i in 0..self.stages {
196 y_new += self.k[i] * (self.b[i] * self.h);
197 }
198
199 self.status = Status::Solving;
201
202 self.t_prev = self.t;
204 self.y_prev = self.y;
205 self.dydt_prev = self.dydt;
206 self.h_prev = self.h;
207
208 self.t += self.h;
210 self.y = y_new;
211
212 ode.diff(self.t, &self.y, &mut self.dydt);
214 evals.fcn += 1;
215
216 Ok(evals)
217 }
218
219 fn t(&self) -> T { self.t }
220 fn y(&self) -> &V { &self.y }
221 fn t_prev(&self) -> T { self.t_prev }
222 fn y_prev(&self) -> &V { &self.y_prev }
223 fn h(&self) -> T { self.h }
224 fn set_h(&mut self, h: T) { self.h = h; }
225 fn status(&self) -> &Status<T, V, D> { &self.status }
226 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
227}
228
229impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ImplicitRungeKutta<Ordinary, Fixed, T, V, D, O, S, I> {
230 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
231 if t_interp < self.t_prev || t_interp > self.t {
233 return Err(Error::OutOfBounds {
234 t_interp,
235 t_prev: self.t_prev,
236 t_curr: self.t
237 });
238 }
239
240 let y_interp = cubic_hermite_interpolate(
242 self.t_prev,
243 self.t,
244 &self.y_prev,
245 &self.y,
246 &self.dydt_prev,
247 &self.dydt,
248 t_interp
249 );
250
251 Ok(y_interp)
252 }
253}