1use anyhow::anyhow;
2use tch::IndexOp;
3use tch::Tensor;
4use std::sync::Arc;
5
6pub mod optimizers;
7
8pub enum Solver {
9 Euler { step: f64 },
10 RK4 { step: f64 },
11 ImplicitEuler { step: f64, optimizer: Arc<dyn optimizers::Optimizer> },
12 GLRK4 { step: f64, optimizer: Arc<dyn optimizers::Optimizer> },
13 RKF45 { rtol: f64, atol: f64, min_step: f64, safety_factor: f64 },
14 ROW1 { step: f64 }
15}
16
17impl Solver {
18 pub fn solve(
19 &self,
20 f: tch::CModule,
21 x_span: Tensor,
22 y0: Tensor
23 ) -> anyhow::Result<(Tensor, Tensor)> {
24 if x_span.size() != [2] {
25 return Err(anyhow!("x_span must be of shape [2] but it has shape {:?}", x_span.size().as_slice()));
26 }
27 if y0.size().len() != 1 {
28 return Err(anyhow!("y0 must be a one-dimensional tensor but it has {} dimensions", y0.size().len()));
29 }
30 if x_span.device() != y0.device() {
31 return Err(anyhow!("x_span and y0 must reside on the same device. Device of x_span is {:?}. Device of y0 is {:?}", x_span.device(), y0.device()));
32 }
33 if x_span.kind() != tch::Kind::Double && x_span.kind() != tch::Kind::Float && x_span.kind() != tch::Kind::BFloat16 && x_span.kind() != tch::Kind::Half {
34 return Err(anyhow!("x_span is of unsupported kind {:?}", x_span.kind()));
35 }
36 if y0.kind() != tch::Kind::Double && y0.kind() != tch::Kind::Float && y0.kind() != tch::Kind::BFloat16 && y0.kind() != tch::Kind::Half {
37 return Err(anyhow!("y0 is of unsupported kind {:?}", y0.kind()));
38 }
39 if x_span.kind() != y0.kind() {
40 return Err(anyhow!("x_span and y0 must be of the same kind. Kind of x_span is {:?}. Kind of y0 is {:?}", x_span.kind(), y0.kind()));
41 }
42
43 match self {
44 Self::Euler { step } => solve_euler(f, x_span, y0, *step),
45 Self::RK4 { step } => solve_rk4(f, x_span, y0, *step),
46 Self::ImplicitEuler { step, optimizer } => solve_implicit_euler(f, x_span, y0, *step, optimizer.as_ref()),
47 Self::GLRK4 { step, optimizer } => solve_glrk4(f, x_span, y0, *step, optimizer.as_ref()),
48 Self::RKF45 { rtol, atol, min_step, safety_factor } => solve_rkf45(f, x_span, y0, *rtol, *atol, *min_step, *safety_factor),
49 Self::ROW1 { step } => solve_row1(f, x_span, y0, *step)
50 }
51 }
52}
53
54fn solve_euler(
56 f: tch::CModule,
57 x_span: Tensor,
58 y0: Tensor,
59 step: f64,
60) -> anyhow::Result<(Tensor, Tensor)> {
61 let x_start = x_span.i(0);
62 let x_end = x_span.i(1);
63
64 let mut x = x_start.unsqueeze(0);
65 let mut y = y0.unsqueeze(0);
66
67 let mut all_x = vec![x.copy()];
68 let mut all_y = vec![y.copy()];
69
70 let mut current_step = step;
71 while x.lt_tensor(&x_end) == Tensor::from_slice(&[true]) {
72 let remaining = &x_end - &x.squeeze();
73 if remaining.double_value(&[]) < current_step {
74 current_step = remaining.double_value(&[]);
75 }
76
77 let dy = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;
78 y = &y + current_step * &dy;
79 x = &x + current_step;
80
81 all_x.push(x.copy());
82 all_y.push(y.copy());
83 }
84
85 Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
86}
87
88fn solve_rk4(
90 f: tch::CModule,
91 x_span: Tensor,
92 y0: Tensor,
93 step: f64,
94) -> anyhow::Result<(Tensor, Tensor)> {
95 let x_start = x_span.i(0);
96 let x_end = x_span.i(1);
97
98 let mut x = x_start.unsqueeze(0);
99 let mut y = y0.unsqueeze(0);
100
101 let mut all_x = vec![x.copy()];
102 let mut all_y = vec![y.copy()];
103
104 let mut current_step = step;
105 while x.lt_tensor(&x_end) == Tensor::from_slice(&[true]) {
106 let remaining = &x_end - &x.squeeze();
107 if remaining.double_value(&[]) < current_step {
108 current_step = remaining.double_value(&[]);
109 }
110
111 let k1 = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;
112
113 let x_half: Tensor = &x + 0.5 * current_step;
114 let y_half: Tensor = &y + 0.5 * current_step * &k1;
115 let k2 = f.forward_ts(&[x_half.squeeze(), y_half.squeeze()])?;
116
117 let x_half_again: Tensor = &x + 0.5 * current_step;
118 let y_half_again: Tensor = &y + 0.5 * current_step * &k2;
119 let k3 = f.forward_ts(&[x_half_again.squeeze(), y_half_again.squeeze()])?;
120
121 let x_full = &x + current_step;
122 let y_full = &y + current_step * &k3;
123 let k4 = f.forward_ts(&[x_full.squeeze(), y_full.squeeze()])?;
124
125 let step_div_6 = current_step / 6.0;
126 let y_next = &y + step_div_6 * (&k1 + 2.0 * &k2 + 2.0 * &k3 + &k4);
127
128 x = &x + current_step;
129 y = y_next;
130
131 all_x.push(x.copy());
132 all_y.push(y.copy());
133 }
134
135 Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
136}
137
138fn solve_implicit_euler(
140 f: tch::CModule,
141 x_span: Tensor,
142 y0: Tensor,
143 step: f64,
144 optimizer: &dyn optimizers::Optimizer,
145) -> anyhow::Result<(Tensor, Tensor)> {
146 let x_start = x_span.i(0);
147 let x_end = x_span.i(1);
148
149 let mut x = x_start.unsqueeze(0);
150 let mut y = y0.unsqueeze(0);
151
152 let mut all_x = vec![x.copy()];
153 let mut all_y = vec![y.copy()];
154
155 let mut current_step = step;
156 while x.lt_tensor(&x_end) == Tensor::from_slice(&[true]) {
157 let remaining = &x_end - &x.squeeze();
158 if remaining.double_value(&[]) < current_step {
159 current_step = remaining.double_value(&[]);
160 }
161
162 let x_next = &x + current_step;
163 let y_prev = y.copy();
164
165 let y_next = optimizer.optimize(
166 &|y_next: &Tensor| {
167 let f_next = f
168 .forward_ts(&[x_next.squeeze().copy(), y_next.squeeze().copy()])
169 .unwrap();
170 let y_pred = &y_prev.squeeze() + current_step * &f_next;
171 (y_next - &y_pred).pow_tensor_scalar(2).sum(y_next.kind())
172 },
173 &(&y_prev.detach().squeeze()
174 + current_step * f.forward_ts(&[&x.squeeze(), &y_prev.squeeze()])?),
175 ).map_err( |err| {
176 anyhow!(format!("Optimizer failed with: {}", err))
177 })?;
178
179 y = y_next.unsqueeze(0);
180 x = x_next.copy();
181
182 all_x.push(x.copy());
183 all_y.push(y.copy());
184 }
185
186 Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
187}
188
189fn solve_glrk4(
191 f: tch::CModule,
192 x_span: Tensor,
193 y0: Tensor,
194 step: f64,
195 optimizer: &dyn optimizers::Optimizer,
196) -> anyhow::Result<(Tensor, Tensor)> {
197 let x_start = x_span.i(0);
198 let x_end = x_span.i(1);
199
200 let mut x = x_start.unsqueeze(0);
201 let mut y = y0.unsqueeze(0);
202
203 let mut all_x = vec![x.copy()];
204 let mut all_y = vec![y.copy()];
205
206 let mut current_step = step;
207 while x.lt_tensor(&x_end) == Tensor::from_slice(&[true]) {
208 let remaining = &x_end - &x.squeeze();
209 if remaining.double_value(&[]) < current_step {
210 current_step = remaining.double_value(&[]);
211 }
212
213 let k = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;
214
215 const C1: f64 = 0.2113248654f64;
216 const C2: f64 = 0.7886751346f64;
217 const A11: f64 = 0.25;
218 const A12: f64 = -0.03867513459f64;
219 const A21: f64 = 0.5386751346f64;
220 const A22: f64 = 0.25;
221
222 let first_k1k2_guess = Tensor::cat(
223 &[
224 f.forward_ts(&[
225 &x.squeeze() + C1 * current_step,
226 &y.squeeze() + C1 * current_step * &k,
227 ])?,
228 f.forward_ts(&[
229 &x.squeeze() + C2 * current_step,
230 &y.squeeze() + C2 * current_step * &k,
231 ])?,
232 ],
233 0,
234 );
235 let k1k2 = optimizer.optimize(
236 &|k1k2_guess| {
237 let diff1 = k1k2_guess.i(0..=1)
238 - f.forward_ts(&[
239 &x.squeeze() + C1 * current_step,
240 &y.squeeze()
241 + (A11 * k1k2_guess.i(0..=1) + A12 * k1k2_guess.i(2..=3))
242 * current_step,
243 ])
244 .unwrap();
245 let diff2 = k1k2_guess.i(2..=3)
246 - f.forward_ts(&[
247 &x.squeeze() + C2 * current_step,
248 &y.squeeze()
249 + (A21 * k1k2_guess.i(0..=1) + A22 * k1k2_guess.i(2..=3))
250 * current_step,
251 ])
252 .unwrap();
253
254 diff1.dot(&diff1) + diff2.dot(&diff2)
255 },
256 &first_k1k2_guess,
257 ).map_err( |err| {
258 anyhow!(format!("Optimizer failed with: {}", err))
259 })?;
260 assert!(k1k2.size().len() == 1);
261 assert!(k1k2.size()[0] == 4);
262
263 x = &x + current_step;
264 y = &y + current_step * (0.5 * k1k2.i(0..=1) + 0.5 * k1k2.i(2..=3));
265
266 all_x.push(x.copy());
267 all_y.push(y.copy());
268 }
269
270 Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
271}
272
273fn solve_rkf45(
275 f: tch::CModule,
276 x_span: Tensor,
277 y0: Tensor,
278 rtol: f64,
279 atol: f64,
280 min_step: f64,
281 safety_factor: f64,
282) -> anyhow::Result<(Tensor, Tensor)> {
283 let x_start = x_span.i(0);
284 let x_end = x_span.i(1);
285
286 let mut x = x_start.unsqueeze(0);
287 let mut y = y0.unsqueeze(0);
288
289 let mut all_x = vec![x.copy()];
290 let mut all_y = vec![y.copy()];
291
292 let mut step = (&x_end - &x_start) * 0.1;
293 let safety_factor_tensor = Tensor::from(safety_factor);
294
295 while x.lt_tensor(&x_end) == Tensor::from_slice(&[true]) {
296 let remaining = &x_end - &x.squeeze();
297 if remaining.lt_tensor(&step) == Tensor::from(true) {
298 step = remaining.copy();
299 }
300
301 let k1 = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;
302
303 let k2 = {
304 let x_step: Tensor = &x + 0.25 * &step;
305 let y_step: Tensor = &y + 0.25 * &step * &k1;
306 f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
307 };
308
309 let k3 = {
310 let x_step: Tensor = &x + 0.375 * &step;
311 let y_step: Tensor = &y + (0.09375 * &step * &k1) + (0.28125 * &step * &k2);
312 f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
313 };
314
315 let k4 = {
316 let x_step: Tensor = &x + (12.0 / 13.0) * &step;
317 let y_step: Tensor = &y
318 + (1932.0 / 2197.0 * &step * &k1)
319 + (-7200.0 / 2197.0 * &step * &k2)
320 + (7296.0 / 2197.0 * &step * &k3);
321 f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
322 };
323
324 let k5 = {
325 let x_step: Tensor = &x + &step;
326 let y_step: Tensor = &y
327 + (439.0 / 216.0 * &step * &k1)
328 + (-8.0 * &step * &k2)
329 + (3680.0 / 513.0 * &step * &k3)
330 + (-845.0 / 4104.0 * &step * &k4);
331 f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
332 };
333
334 let k6 = {
335 let x_step: Tensor = &x + 0.5 * &step;
336 let y_step: Tensor = &y
337 + (-8.0 / 27.0 * &step * &k1)
338 + (2.0 * &step * &k2)
339 + (-3544.0 / 2565.0 * &step * &k3)
340 + (1859.0 / 4104.0 * &step * &k4)
341 + (-11.0 / 40.0 * &step * &k5);
342 f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
343 };
344
345 let next_y4: Tensor = &y
346 + &step
347 * ((25.0 / 216.0 * &k1)
348 + (1408.0 / 2565.0 * &k3)
349 + (2197.0 / 4104.0 * &k4)
350 + (-1.0 / 5.0 * &k5));
351 let next_y5: Tensor = &y
352 + &step
353 * ((16.0 / 135.0 * &k1)
354 + (6656.0 / 12825.0 * &k3)
355 + (28561.0 / 56430.0 * &k4)
356 + (-9.0 / 50.0 * &k5)
357 + (2.0 / 55.0 * &k6));
358
359 let d = (&next_y4 - &next_y5).abs();
360 let e = next_y5.abs() * rtol + atol;
361
362 let alpha_tensor = (e / d).sqrt().min();
363 let condition = &safety_factor_tensor * &alpha_tensor;
364
365 let condition_met = condition.lt(1.0);
366 let condition_met_bool: bool = condition_met == Tensor::from(true);
367
368 if condition_met_bool {
369 step = &step * &condition;
370 if step.double_value(&[]) < min_step {
371 return Err(anyhow!("Required step is smaller than minimal step"));
372 }
373 } else {
374 y = next_y4;
375 x = &x + &step;
376 all_x.push(x.copy());
377 all_y.push(y.copy());
378
379 let new_step = &step * &condition;
380 let max_step = &step * 5.0;
381 step = new_step.fmin(&max_step);
382 }
383 }
384
385 Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
386}
387
388fn solve_row1(
390 f: tch::CModule,
391 x_span: Tensor,
392 y0: Tensor,
393 step: f64,
394) -> anyhow::Result<(Tensor, Tensor)> {
395 let x_start = x_span.i(0);
396 let x_end = x_span.i(1);
397
398 let mut x = x_start.unsqueeze(0);
399 let mut y = y0.unsqueeze(0);
400
401 let mut all_x = vec![x.copy()];
402 let mut all_y = vec![y.copy()];
403
404 while x.lt_tensor(&x_end) == Tensor::from_slice(&[true]) {
405 let remaining = &x_end - &x.squeeze();
406 let mut current_step = step;
407 if remaining.double_value(&[]) < step {
408 current_step = remaining.double_value(&[]);
409 }
410
411 let x_prev = x.copy();
412 let y_prev = y.copy().squeeze();
413
414 let jacobian = compute_jacobian(
415 |y| {
416 f.forward_ts(&[x_prev.squeeze().copy(), y.copy()])
417 .unwrap()
418 .squeeze()
419 },
420 &y_prev,
421 );
422 let f_current = f
423 .forward_ts(&[x_prev.squeeze().copy(), y_prev.copy()])?
424 .squeeze();
425
426 let n = jacobian.size()[0];
427 let eye = Tensor::eye(n, (tch::Kind::Float, jacobian.device()));
428 let step_j = current_step * &jacobian;
429 let inv_matrix = (eye - step_j).inverse();
430
431 let delta_y = inv_matrix.matmul(&f_current);
432 let y_next = y_prev + current_step * delta_y;
433
434 x = &x_prev + current_step;
435 y = y_next.unsqueeze(0);
436
437 all_x.push(x.copy());
438 all_y.push(y.copy());
439 }
440
441 Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
442}
443
444fn compute_jacobian<F>(f: F, x: &Tensor) -> Tensor
446where
447 F: Fn(&Tensor) -> Tensor,
448{
449 assert_eq!(x.dim(), 1, "x must be 1-dimensional");
450 let mut x_with_grad = x.detach().copy().set_requires_grad(true);
451 let y = f(&x_with_grad);
452 assert_eq!(y.dim(), 1, "y must be 1-dimensional");
453
454 let y_size = y.size()[0];
455 let mut grads = Vec::new();
456
457 for i in 0..y_size {
458 let yi = y.i(i);
459 let grad = Tensor::run_backward(&[yi], &[&x_with_grad], true, false)[0].copy();
462 grads.push(grad.unsqueeze(0));
463 x_with_grad.zero_grad();
464 }
465
466 Tensor::cat(&grads, 0)
467}