mini_ode/
lib.rs

use anyhow::anyhow;
use std::sync::Arc;
use tch::IndexOp;
use tch::Tensor;

pub mod optimizers;

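/// Available ODE integration schemes. The fixed-step solvers take a step
/// size `step`; `ImplicitEuler` and `GLRK4` additionally need an
/// [`optimizers::Optimizer`] to solve their implicit stage equations, and
/// `RKF45` adapts its step size from `rtol`/`atol`.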
pub enum Solver {
    Euler { step: f64 },
    RK4 { step: f64 },
    ImplicitEuler { step: f64, optimizer: Arc<dyn optimizers::Optimizer> },
    GLRK4 { step: f64, optimizer: Arc<dyn optimizers::Optimizer> },
    RKF45 { rtol: f64, atol: f64, min_step: f64, safety_factor: f64 },
    ROW1 { step: f64 },
}

impl Solver {
    pub fn solve(
        &self,
        f: tch::CModule,
        x_span: Tensor,
        y0: Tensor,
    ) -> anyhow::Result<(Tensor, Tensor)> {
        let supported = |kind: tch::Kind| {
            matches!(kind, tch::Kind::Double | tch::Kind::Float | tch::Kind::BFloat16 | tch::Kind::Half)
        };

        if x_span.size() != [2] {
            return Err(anyhow!("x_span must be of shape [2] but it has shape {:?}", x_span.size().as_slice()));
        }
        if y0.size().len() != 1 {
            return Err(anyhow!("y0 must be a one-dimensional tensor but it has {} dimensions", y0.size().len()));
        }
        if x_span.device() != y0.device() {
            return Err(anyhow!("x_span and y0 must reside on the same device. Device of x_span is {:?}. Device of y0 is {:?}", x_span.device(), y0.device()));
        }
        if !supported(x_span.kind()) {
            return Err(anyhow!("x_span is of unsupported kind {:?}", x_span.kind()));
        }
        if !supported(y0.kind()) {
            return Err(anyhow!("y0 is of unsupported kind {:?}", y0.kind()));
        }
        if x_span.kind() != y0.kind() {
            return Err(anyhow!("x_span and y0 must be of the same kind. Kind of x_span is {:?}. Kind of y0 is {:?}", x_span.kind(), y0.kind()));
        }

        match self {
            Self::Euler { step } => solve_euler(f, x_span, y0, *step),
            Self::RK4 { step } => solve_rk4(f, x_span, y0, *step),
            Self::ImplicitEuler { step, optimizer } => solve_implicit_euler(f, x_span, y0, *step, optimizer.as_ref()),
            Self::GLRK4 { step, optimizer } => solve_glrk4(f, x_span, y0, *step, optimizer.as_ref()),
            Self::RKF45 { rtol, atol, min_step, safety_factor } => solve_rkf45(f, x_span, y0, *rtol, *atol, *min_step, *safety_factor),
            Self::ROW1 { step } => solve_row1(f, x_span, y0, *step),
        }
    }
}
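
// Usage sketch (the module path and step size below are hypothetical):
// `f` is a TorchScript module mapping `(x, y)` to `dy/dx`.
//
//     let f = tch::CModule::load("f.pt")?;
//     let solver = Solver::RK4 { step: 1e-2 };
//     let x_span = Tensor::from_slice(&[0.0f32, 1.0]);
//     let y0 = Tensor::from_slice(&[1.0f32]);
//     let (xs, ys) = solver.solve(f, x_span, y0)?;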

/// Solves the ODE using the explicit Euler method.
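///
/// Update rule: `y_{n+1} = y_n + h * f(x_n, y_n)`. The final step is
/// shortened so the trajectory lands exactly on `x_end`.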
fn solve_euler(
    f: tch::CModule,
    x_span: Tensor,
    y0: Tensor,
    step: f64,
) -> anyhow::Result<(Tensor, Tensor)> {
    let x_start = x_span.i(0);
    let x_end = x_span.i(1);

    let mut x = x_start.unsqueeze(0);
    let mut y = y0.unsqueeze(0);

    let mut all_x = vec![x.copy()];
    let mut all_y = vec![y.copy()];

    let mut current_step = step;
    // Compare as scalars so the loop also works for tensors on non-CPU devices.
    while x.double_value(&[0]) < x_end.double_value(&[]) {
        // Shorten the final step so the trajectory lands exactly on x_end.
        let remaining = &x_end - &x.squeeze();
        if remaining.double_value(&[]) < current_step {
            current_step = remaining.double_value(&[]);
        }

        let dy = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;
        y = &y + current_step * &dy;
        x = &x + current_step;

        all_x.push(x.copy());
        all_y.push(y.copy());
    }

    Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
}

/// Solves the ODE using the classical 4th-order Runge-Kutta method.
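///
/// Four slope evaluations per step (start, two midpoint estimates, end),
/// combined as `y_{n+1} = y_n + h/6 * (k1 + 2*k2 + 2*k3 + k4)`.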
fn solve_rk4(
    f: tch::CModule,
    x_span: Tensor,
    y0: Tensor,
    step: f64,
) -> anyhow::Result<(Tensor, Tensor)> {
    let x_start = x_span.i(0);
    let x_end = x_span.i(1);

    let mut x = x_start.unsqueeze(0);
    let mut y = y0.unsqueeze(0);

    let mut all_x = vec![x.copy()];
    let mut all_y = vec![y.copy()];

    let mut current_step = step;
    while x.double_value(&[0]) < x_end.double_value(&[]) {
        let remaining = &x_end - &x.squeeze();
        if remaining.double_value(&[]) < current_step {
            current_step = remaining.double_value(&[]);
        }

        // k1: slope at the start of the interval.
        let k1 = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;

        // k2 and k3: slopes at the midpoint, advanced with k1 and k2 respectively.
        let x_half: Tensor = &x + 0.5 * current_step;
        let y_half: Tensor = &y + 0.5 * current_step * &k1;
        let k2 = f.forward_ts(&[x_half.squeeze(), y_half.squeeze()])?;

        let y_half_again: Tensor = &y + 0.5 * current_step * &k2;
        let k3 = f.forward_ts(&[x_half.squeeze(), y_half_again.squeeze()])?;

        // k4: slope at the end of the interval, advanced with k3.
        let x_full = &x + current_step;
        let y_full = &y + current_step * &k3;
        let k4 = f.forward_ts(&[x_full.squeeze(), y_full.squeeze()])?;

        let step_div_6 = current_step / 6.0;
        let y_next = &y + step_div_6 * (&k1 + 2.0 * &k2 + 2.0 * &k3 + &k4);

        x = &x + current_step;
        y = y_next;

        all_x.push(x.copy());
        all_y.push(y.copy());
    }

    Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
}

/// Solves the ODE using the implicit (backward) Euler method, with the
/// nonlinear step equation solved by the supplied optimizer.
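///
/// Each step solves `y_{n+1} = y_n + h * f(x_{n+1}, y_{n+1})` for `y_{n+1}`
/// by minimizing the squared residual, starting from an explicit-Euler
/// predictor.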
fn solve_implicit_euler(
    f: tch::CModule,
    x_span: Tensor,
    y0: Tensor,
    step: f64,
    optimizer: &dyn optimizers::Optimizer,
) -> anyhow::Result<(Tensor, Tensor)> {
    let x_start = x_span.i(0);
    let x_end = x_span.i(1);

    let mut x = x_start.unsqueeze(0);
    let mut y = y0.unsqueeze(0);

    let mut all_x = vec![x.copy()];
    let mut all_y = vec![y.copy()];

    let mut current_step = step;
    while x.double_value(&[0]) < x_end.double_value(&[]) {
        let remaining = &x_end - &x.squeeze();
        if remaining.double_value(&[]) < current_step {
            current_step = remaining.double_value(&[]);
        }

        let x_next = &x + current_step;
        let y_prev = y.copy();

        // Minimize || y_next - (y_prev + h * f(x_next, y_next)) ||^2,
        // starting from an explicit-Euler guess.
        let y_next = optimizer
            .optimize(
                &|y_next: &Tensor| {
                    let f_next = f
                        .forward_ts(&[x_next.squeeze().copy(), y_next.squeeze().copy()])
                        .unwrap();
                    let y_pred = &y_prev.squeeze() + current_step * &f_next;
                    (y_next - &y_pred).pow_tensor_scalar(2).sum(y_next.kind())
                },
                &(&y_prev.detach().squeeze()
                    + current_step * f.forward_ts(&[&x.squeeze(), &y_prev.squeeze()])?),
            )
            .map_err(|err| anyhow!("Optimizer failed with: {}", err))?;

        y = y_next.unsqueeze(0);
        x = x_next.copy();

        all_x.push(x.copy());
        all_y.push(y.copy());
    }

    Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
}

/// Solves the ODE using the two-stage Gauss-Legendre Runge-Kutta method
/// of order 4.
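///
/// The implicit stage slopes `k1`, `k2` satisfy
/// `k_i = f(x_n + c_i * h, y_n + h * (a_i1 * k1 + a_i2 * k2))`; they are
/// found by minimizing the squared residual of these equations, and the
/// update is `y_{n+1} = y_n + h/2 * (k1 + k2)`.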
fn solve_glrk4(
    f: tch::CModule,
    x_span: Tensor,
    y0: Tensor,
    step: f64,
    optimizer: &dyn optimizers::Optimizer,
) -> anyhow::Result<(Tensor, Tensor)> {
    let x_start = x_span.i(0);
    let x_end = x_span.i(1);
    // Length of the state vector; the two stage slopes are optimized as one
    // concatenated tensor of length 2 * n.
    let n = y0.size()[0];

    let mut x = x_start.unsqueeze(0);
    let mut y = y0.unsqueeze(0);

    let mut all_x = vec![x.copy()];
    let mut all_y = vec![y.copy()];

    let mut current_step = step;
    while x.double_value(&[0]) < x_end.double_value(&[]) {
        let remaining = &x_end - &x.squeeze();
        if remaining.double_value(&[]) < current_step {
            current_step = remaining.double_value(&[]);
        }

        // Slope at the current point, used to build the initial stage guess.
        let k = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;

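        // Two-stage Gauss-Legendre tableau (order 4):
        // c = 1/2 ∓ √3/6, b = (1/2, 1/2),
        // A = [[1/4, 1/4 - √3/6], [1/4 + √3/6, 1/4]].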
        let sqrt3_6 = 3f64.sqrt() / 6.0;
        let c1 = 0.5 - sqrt3_6;
        let c2 = 0.5 + sqrt3_6;
        let a11 = 0.25;
        let a12 = 0.25 - sqrt3_6;
        let a21 = 0.25 + sqrt3_6;
        let a22 = 0.25;

        let first_k1k2_guess = Tensor::cat(
            &[
                f.forward_ts(&[
                    &x.squeeze() + c1 * current_step,
                    &y.squeeze() + c1 * current_step * &k,
                ])?,
                f.forward_ts(&[
                    &x.squeeze() + c2 * current_step,
                    &y.squeeze() + c2 * current_step * &k,
                ])?,
            ],
            0,
        );
        // Solve the implicit stage equations by minimizing the squared
        // residual; the first n entries are k1, the last n entries are k2.
        // (The original slices hardcoded a state dimension of 2.)
        let k1k2 = optimizer
            .optimize(
                &|k1k2_guess| {
                    let diff1 = k1k2_guess.i(0..n)
                        - f.forward_ts(&[
                            &x.squeeze() + c1 * current_step,
                            &y.squeeze()
                                + (a11 * k1k2_guess.i(0..n) + a12 * k1k2_guess.i(n..2 * n))
                                    * current_step,
                        ])
                        .unwrap();
                    let diff2 = k1k2_guess.i(n..2 * n)
                        - f.forward_ts(&[
                            &x.squeeze() + c2 * current_step,
                            &y.squeeze()
                                + (a21 * k1k2_guess.i(0..n) + a22 * k1k2_guess.i(n..2 * n))
                                    * current_step,
                        ])
                        .unwrap();

                    diff1.dot(&diff1) + diff2.dot(&diff2)
                },
                &first_k1k2_guess,
            )
            .map_err(|err| anyhow!("Optimizer failed with: {}", err))?;
        assert_eq!(k1k2.size(), [2 * n]);

        x = &x + current_step;
        y = &y + current_step * (0.5 * k1k2.i(0..n) + 0.5 * k1k2.i(n..2 * n));

        all_x.push(x.copy());
        all_y.push(y.copy());
    }

    Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
}

/// Solves the ODE using the adaptive Runge-Kutta-Fehlberg 4(5) method.
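///
/// Six slope evaluations per trial step yield embedded 4th- and 5th-order
/// solutions; their difference drives the step-size control. A step whose
/// scaled error is too large is rejected and shrunk, and the solver fails
/// if the required step falls below `min_step`.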
fn solve_rkf45(
    f: tch::CModule,
    x_span: Tensor,
    y0: Tensor,
    rtol: f64,
    atol: f64,
    min_step: f64,
    safety_factor: f64,
) -> anyhow::Result<(Tensor, Tensor)> {
    let x_start = x_span.i(0);
    let x_end = x_span.i(1);

    let mut x = x_start.unsqueeze(0);
    let mut y = y0.unsqueeze(0);

    let mut all_x = vec![x.copy()];
    let mut all_y = vec![y.copy()];

    // Initial trial step: 10% of the integration interval.
    let mut step = (&x_end - &x_start) * 0.1;

    while x.double_value(&[0]) < x_end.double_value(&[]) {
        let remaining = &x_end - &x.squeeze();
        if remaining.double_value(&[]) < step.double_value(&[]) {
            step = remaining.copy();
        }

        // The six Fehlberg stages; coefficients follow the standard RKF45
        // Butcher tableau.
        let k1 = f.forward_ts(&[x.squeeze().copy(), y.squeeze().copy()])?;

        let k2 = {
            let x_step: Tensor = &x + 0.25 * &step;
            let y_step: Tensor = &y + 0.25 * &step * &k1;
            f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
        };

        let k3 = {
            let x_step: Tensor = &x + 0.375 * &step;
            let y_step: Tensor = &y + (0.09375 * &step * &k1) + (0.28125 * &step * &k2);
            f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
        };

        let k4 = {
            let x_step: Tensor = &x + (12.0 / 13.0) * &step;
            let y_step: Tensor = &y
                + (1932.0 / 2197.0 * &step * &k1)
                + (-7200.0 / 2197.0 * &step * &k2)
                + (7296.0 / 2197.0 * &step * &k3);
            f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
        };

        let k5 = {
            let x_step: Tensor = &x + &step;
            let y_step: Tensor = &y
                + (439.0 / 216.0 * &step * &k1)
                + (-8.0 * &step * &k2)
                + (3680.0 / 513.0 * &step * &k3)
                + (-845.0 / 4104.0 * &step * &k4);
            f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
        };

        let k6 = {
            let x_step: Tensor = &x + 0.5 * &step;
            let y_step: Tensor = &y
                + (-8.0 / 27.0 * &step * &k1)
                + (2.0 * &step * &k2)
                + (-3544.0 / 2565.0 * &step * &k3)
                + (1859.0 / 4104.0 * &step * &k4)
                + (-11.0 / 40.0 * &step * &k5);
            f.forward_ts(&[x_step.squeeze(), y_step.squeeze()])?
        };

        // Embedded 4th-order and 5th-order solutions.
        let next_y4: Tensor = &y
            + &step
                * ((25.0 / 216.0 * &k1)
                    + (1408.0 / 2565.0 * &k3)
                    + (2197.0 / 4104.0 * &k4)
                    + (-1.0 / 5.0 * &k5));
        let next_y5: Tensor = &y
            + &step
                * ((16.0 / 135.0 * &k1)
                    + (6656.0 / 12825.0 * &k3)
                    + (28561.0 / 56430.0 * &k4)
                    + (-9.0 / 50.0 * &k5)
                    + (2.0 / 55.0 * &k6));

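        // Componentwise error estimate: the difference between the embedded
        // solutions is compared against the mixed absolute/relative tolerance,
        // and the most restrictive component decides the step-size factor.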
        let d = (&next_y4 - &next_y5).abs();
        let e = next_y5.abs() * rtol + atol;

        let alpha = (e / d).sqrt().min();
        let condition = safety_factor * &alpha;

        // Compare as a scalar so the check also works on non-CPU devices.
        if condition.double_value(&[]) < 1.0 {
            // Reject the step: the error is too large, so shrink the step.
            step = &step * &condition;
            if step.double_value(&[]) < min_step {
                return Err(anyhow!("Required step is smaller than minimal step"));
            }
        } else {
            // Accept the step and grow the step size, at most fivefold.
            y = next_y4;
            x = &x + &step;
            all_x.push(x.copy());
            all_y.push(y.copy());

            let new_step = &step * &condition;
            let max_step = &step * 5.0;
            step = new_step.fmin(&max_step);
        }
    }

    Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
}

/// Solves the ODE using the first-order Rosenbrock method (ROW1), also known
/// as the linearly implicit Euler method.
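///
/// Each step applies `y_{n+1} = y_n + h * (I - h * J)^{-1} * f(x_n, y_n)`,
/// where `J` is the Jacobian of `f` with respect to `y` at the current point.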
fn solve_row1(
    f: tch::CModule,
    x_span: Tensor,
    y0: Tensor,
    step: f64,
) -> anyhow::Result<(Tensor, Tensor)> {
    let x_start = x_span.i(0);
    let x_end = x_span.i(1);

    let mut x = x_start.unsqueeze(0);
    let mut y = y0.unsqueeze(0);

    let mut all_x = vec![x.copy()];
    let mut all_y = vec![y.copy()];

    while x.double_value(&[0]) < x_end.double_value(&[]) {
        let remaining = &x_end - &x.squeeze();
        let mut current_step = step;
        if remaining.double_value(&[]) < step {
            current_step = remaining.double_value(&[]);
        }

        let x_prev = x.copy();
        let y_prev = y.copy().squeeze();

        // Linearize f around the current point: J = df/dy at (x_n, y_n).
        let jacobian = compute_jacobian(
            |y| {
                f.forward_ts(&[x_prev.squeeze().copy(), y.copy()])
                    .unwrap()
                    .squeeze()
            },
            &y_prev,
        );
        let f_current = f
            .forward_ts(&[x_prev.squeeze().copy(), y_prev.copy()])?
            .squeeze();

        // y_{n+1} = y_n + h * (I - h*J)^{-1} * f(x_n, y_n). The identity
        // matrix matches the Jacobian's kind (instead of hardcoding Float),
        // so Double inputs stay in Double.
        let n = jacobian.size()[0];
        let eye = Tensor::eye(n, (jacobian.kind(), jacobian.device()));
        let step_j = current_step * &jacobian;
        let inv_matrix = (eye - step_j).inverse();

        let delta_y = inv_matrix.matmul(&f_current);
        let y_next = y_prev + current_step * delta_y;

        x = &x_prev + current_step;
        y = y_next.unsqueeze(0);

        all_x.push(x.copy());
        all_y.push(y.copy());
    }

    Ok((Tensor::cat(&all_x, 0), Tensor::cat(&all_y, 0)))
}

/// Computes the Jacobian matrix of a function `f` at point `x`.
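///
/// Row `i` holds the gradient of `y[i]` with respect to `x`, computed with
/// one reverse-mode pass per output component (`Tensor::run_backward` with
/// `keep_graph = true`, so the graph survives across rows).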
fn compute_jacobian<F>(f: F, x: &Tensor) -> Tensor
where
    F: Fn(&Tensor) -> Tensor,
{
    assert_eq!(x.dim(), 1, "x must be 1-dimensional");
    let x_with_grad = x.detach().copy().set_requires_grad(true);
    let y = f(&x_with_grad);
    assert_eq!(y.dim(), 1, "y must be 1-dimensional");

    let y_size = y.size()[0];
    let mut grads = Vec::new();

    for i in 0..y_size {
        let yi = y.i(i);
        // run_backward returns the gradient directly instead of accumulating
        // into x_with_grad.grad(), so no zero_grad between rows is needed.
        let grad = Tensor::run_backward(&[yi], &[&x_with_grad], true, false)[0].copy();
        grads.push(grad.unsqueeze(0));
    }

    Tensor::cat(&grads, 0)
}