Skip to main content

oxiphysics_core/
neural_ode.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Neural ODE (continuous-depth neural networks) implementations.
5//!
6//! Implements Neural ODEs as introduced by Chen et al. (NeurIPS 2018).
7//! Provides:
8//! - [`NeuralOdeFunc`]: parameterised ODE right-hand side (small MLP)
9//! - [`NeuralOdeSolver`]: wraps a [`NeuralOdeFunc`] with RK4 integration
10//! - [`AdjointMethod`]: reverse-mode gradient estimation via the adjoint
11//! - [`LatentOde`]: encoder-dynamics-decoder architecture
12//! - [`TimeSeriesOde`]: convenience wrapper for time-series fitting
13//! - Free functions [`rk4_step`] and [`dopri5_step`]
14
15#![allow(dead_code)]
16#![allow(clippy::too_many_arguments)]
17
18// ─────────────────────────────────────────────────────────────────────────────
19// Free integration helpers
20// ─────────────────────────────────────────────────────────────────────────────
21
/// Advance an ODE state by one step of the classical fourth-order
/// Runge–Kutta scheme.
///
/// # Arguments
/// * `f` – Right-hand side `f(t, y)` of the ODE; called four times per step.
/// * `t` – Time at the start of the step.
/// * `y` – State at time `t`.
/// * `h` – Step size (a zero step returns a copy of `y`).
///
/// # Returns
/// The approximate state at `t + h`, with the same length as `y`.
///
/// # Panics
/// Panics if `f` returns fewer than `y.len()` components.
pub fn rk4_step(f: &dyn Fn(f64, &[f64]) -> Vec<f64>, t: f64, y: &[f64], h: f64) -> Vec<f64> {
    let dim = y.len();
    // Probe state `y + scale * slope` handed to the next stage.
    let shifted = |scale: f64, slope: &[f64]| -> Vec<f64> {
        (0..dim).map(|i| y[i] + scale * slope[i]).collect()
    };
    let half = 0.5 * h;
    let t_mid = t + half;

    let s1 = f(t, y);
    let s2 = f(t_mid, &shifted(half, &s1));
    let s3 = f(t_mid, &shifted(half, &s2));
    let s4 = f(t + h, &shifted(h, &s3));

    // Simpson-like weighting 1/6, 2/6, 2/6, 1/6 of the four slopes.
    let weight = h / 6.0;
    let mut next = Vec::with_capacity(dim);
    for i in 0..dim {
        next.push(y[i] + weight * (s1[i] + 2.0 * s2[i] + 2.0 * s3[i] + s4[i]));
    }
    next
}
45
/// Perform a single Dormand–Prince 5(4) step (DOPRI5).
///
/// Returns `(y_high, y_low, error_norm)` where
/// * `y_high` is the 5th-order solution at `t + h`,
/// * `y_low` is the embedded 4th-order solution (`y_high - err`), and
/// * `error_norm` is the RMS of `err_i / (atol + rtol * max(|y_i|, |y_high_i|))`,
///   suitable for step-size control (a value `<= 1` means the step meets the
///   tolerances).
///
/// Seven evaluations of `f` are made per call: the six stages `k1…k6` plus
/// `k7 = f(t + h, y_high)`, which enters the error estimate. `k7` is not
/// returned, so callers cannot reuse it as the first stage of the next step.
///
/// The error norm is computed from the error vector itself rather than from
/// `y_high - y_low`; re-deriving the error by subtraction would round it to a
/// multiple of the spacing of `y_high` and lose small estimates entirely.
/// An empty state yields an error norm of `0.0`, and a component whose error
/// is exactly zero contributes zero even when its tolerance scale is zero.
///
/// # Arguments
/// * `f`   – ODE right-hand side `f(t, y)`.
/// * `t`   – Current time.
/// * `y`   – Current state.
/// * `h`   – Proposed step size.
/// * `rtol` – Relative tolerance (used in error norm).
/// * `atol` – Absolute tolerance (used in error norm).
///
/// # Panics
/// Panics if `f` returns fewer than `y.len()` components.
pub fn dopri5_step(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    t: f64,
    y: &[f64],
    h: f64,
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    let n = y.len();
    // Butcher tableau node values (Dormand-Prince)
    let c2 = 1.0 / 5.0;
    let c3 = 3.0 / 10.0;
    let c4 = 4.0 / 5.0;
    let c5 = 8.0 / 9.0;

    let k1 = f(t, y);

    let y2: Vec<f64> = (0..n).map(|i| y[i] + h * (1.0 / 5.0) * k1[i]).collect();
    let k2 = f(t + c2 * h, &y2);

    let y3: Vec<f64> = (0..n)
        .map(|i| y[i] + h * ((3.0 / 40.0) * k1[i] + (9.0 / 40.0) * k2[i]))
        .collect();
    let k3 = f(t + c3 * h, &y3);

    let y4: Vec<f64> = (0..n)
        .map(|i| y[i] + h * ((44.0 / 45.0) * k1[i] - (56.0 / 15.0) * k2[i] + (32.0 / 9.0) * k3[i]))
        .collect();
    let k4 = f(t + c4 * h, &y4);

    let y5: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h
                * ((19372.0 / 6561.0) * k1[i] - (25360.0 / 2187.0) * k2[i]
                    + (64448.0 / 6561.0) * k3[i]
                    - (212.0 / 729.0) * k4[i])
        })
        .collect();
    let k5 = f(t + c5 * h, &y5);

    let y6: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h
                * ((9017.0 / 3168.0) * k1[i] - (355.0 / 33.0) * k2[i]
                    + (46732.0 / 5247.0) * k3[i]
                    + (49.0 / 176.0) * k4[i]
                    - (5103.0 / 18656.0) * k5[i])
        })
        .collect();
    let k6 = f(t + h, &y6);

    // 5th-order solution (b weights: b1=35/384, b3=500/1113, b4=125/192, b5=-2187/6784, b6=11/84)
    let y_high: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h
                * ((35.0 / 384.0) * k1[i] + (500.0 / 1113.0) * k3[i] + (125.0 / 192.0) * k4[i]
                    - (2187.0 / 6784.0) * k5[i]
                    + (11.0 / 84.0) * k6[i])
        })
        .collect();

    // Seventh stage, evaluated at the new point; needed by the error estimate.
    let k7 = f(t + h, &y_high);

    // Error vector using Dormand-Prince error coefficients e_i = b_i - b'_i:
    //   e1=71/57600, e3=-71/16695, e4=71/1920, e5=-17253/339200, e6=22/525, e7=-1/40
    let err: Vec<f64> = (0..n)
        .map(|i| {
            h * ((71.0 / 57600.0) * k1[i] - (71.0 / 16695.0) * k3[i] + (71.0 / 1920.0) * k4[i]
                - (17253.0 / 339200.0) * k5[i]
                + (22.0 / 525.0) * k6[i]
                - (1.0 / 40.0) * k7[i])
        })
        .collect();

    // Embedded 4th-order solution.
    let y_low: Vec<f64> = y_high.iter().zip(err.iter()).map(|(yh, e)| yh - e).collect();

    // Error norm (RMS with mixed tolerance), taken from `err` directly so that
    // tiny estimates are not lost to cancellation.
    let error_norm = if n == 0 {
        // Nothing to integrate: avoid the 0/0 of an empty mean.
        0.0
    } else {
        let sum_sq: f64 = (0..n)
            .map(|i| {
                let e = err[i];
                if e == 0.0 {
                    // Exact component; also avoids 0/0 when the scale is zero.
                    0.0
                } else {
                    let sc = atol + rtol * y[i].abs().max(y_high[i].abs());
                    (e / sc).powi(2)
                }
            })
            .sum();
        (sum_sq / n as f64).sqrt()
    };

    (y_high, y_low, error_norm)
}
156
157// ─────────────────────────────────────────────────────────────────────────────
158// Activation helpers
159// ─────────────────────────────────────────────────────────────────────────────
160
/// Apply `tanh` to every component of `v`, returning a new vector of the
/// same length.
fn tanh_vec(v: &[f64]) -> Vec<f64> {
    let mut squashed = Vec::with_capacity(v.len());
    squashed.extend(v.iter().map(|&x| x.tanh()));
    squashed
}
165
/// Fully connected layer followed by `tanh`: `output = tanh(W * input + b)`.
///
/// `w` holds `out` rows of `input.len()` weights each (row-major) and `b`
/// holds one bias per output unit.
///
/// # Panics
/// Panics if `w` has fewer than `out * input.len()` entries or `b` has fewer
/// than `out` entries.
fn dense_tanh(input: &[f64], w: &[f64], b: &[f64], out: usize) -> Vec<f64> {
    let width = input.len();
    let mut activations = Vec::with_capacity(out);
    for row in 0..out {
        let weights = &w[row * width..(row + 1) * width];
        let pre = weights
            .iter()
            .zip(input)
            .map(|(wij, xj)| wij * xj)
            .sum::<f64>()
            + b[row];
        activations.push(pre.tanh());
    }
    activations
}
178
/// Fully connected layer without activation: `output = W * input + b`.
///
/// `w` is row-major with `out` rows of `input.len()` weights; `b` holds one
/// bias per output unit.
///
/// # Panics
/// Panics if `w` has fewer than `out * input.len()` entries or `b` has fewer
/// than `out` entries.
fn dense_linear(input: &[f64], w: &[f64], b: &[f64], out: usize) -> Vec<f64> {
    let width = input.len();
    (0..out)
        .map(|row| {
            let start = row * width;
            let dot = w[start..start + width]
                .iter()
                .zip(input)
                .map(|(wij, xj)| wij * xj)
                .sum::<f64>();
            dot + b[row]
        })
        .collect()
}
186
187// ─────────────────────────────────────────────────────────────────────────────
188// NeuralOdeFunc
189// ─────────────────────────────────────────────────────────────────────────────
190
/// Right-hand side of a Neural ODE: a two-hidden-layer MLP mapping `(t, z)`
/// to `dz/dt`.
///
/// Data flow: `[z; t] → tanh(W_in·[z; t] + b_in) → tanh(W_h·h1 + b_h)
/// → W_out·h2 + b_out`.
#[derive(Debug, Clone)]
pub struct NeuralOdeFunc {
    /// Length of the state vector `z` (and of the returned derivative).
    pub input_size: usize,
    /// Width of each of the two hidden layers.
    pub hidden_size: usize,
    /// First-layer weights, row-major. `new` allocates
    /// `hidden_size × (input_size + 1)` entries; the extra column multiplies
    /// the time input appended in `forward`.
    pub weights_in: Vec<f64>,
    /// First-layer bias (length `hidden_size`).
    pub bias_in: Vec<f64>,
    /// Second-layer weights, row-major `hidden_size × hidden_size`.
    pub weights_hidden: Vec<f64>,
    /// Second-layer bias (length `hidden_size`).
    pub bias_hidden: Vec<f64>,
    /// Output-layer weights, row-major `input_size × hidden_size`.
    pub weights_out: Vec<f64>,
    /// Output-layer bias (length `input_size`).
    pub bias_out: Vec<f64>,
}
214
215impl NeuralOdeFunc {
216    /// Construct a `NeuralOdeFunc` with all weights initialised to small random
217    /// values using a simple linear congruential generator seeded by `seed`.
218    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
219        let mut rng_state = seed;
220        let mut next = move || -> f64 {
221            rng_state = rng_state
222                .wrapping_mul(6364136223846793005)
223                .wrapping_add(1442695040888963407);
224            // Map to [-0.1, 0.1]
225            let bits = (rng_state >> 11) as f64;
226            (bits / (1u64 << 53) as f64) * 0.2 - 0.1
227        };
228
229        // +1 for the time input that is appended in `forward`
230        let wi: Vec<f64> = (0..hidden_size * (input_size + 1))
231            .map(|_| next())
232            .collect();
233        let bi: Vec<f64> = (0..hidden_size).map(|_| next()).collect();
234        let wh: Vec<f64> = (0..hidden_size * hidden_size).map(|_| next()).collect();
235        let bh: Vec<f64> = (0..hidden_size).map(|_| next()).collect();
236        let wo: Vec<f64> = (0..input_size * hidden_size).map(|_| next()).collect();
237        let bo: Vec<f64> = (0..input_size).map(|_| next()).collect();
238
239        Self {
240            input_size,
241            hidden_size,
242            weights_in: wi,
243            bias_in: bi,
244            weights_hidden: wh,
245            bias_hidden: bh,
246            weights_out: wo,
247            bias_out: bo,
248        }
249    }
250
251    /// Evaluate the ODE right-hand side: `dz/dt = f(t, z)`.
252    ///
253    /// The time `t` is concatenated to `z` before the first layer so the
254    /// network can model non-autonomous dynamics.
255    pub fn forward(&self, t: f64, z: &[f64]) -> Vec<f64> {
256        // Augment state with time
257        let mut aug = Vec::with_capacity(self.input_size + 1);
258        aug.extend_from_slice(z);
259        aug.push(t);
260
261        let h1 = dense_tanh(&aug, &self.weights_in, &self.bias_in, self.hidden_size);
262        let h2 = dense_tanh(
263            &h1,
264            &self.weights_hidden,
265            &self.bias_hidden,
266            self.hidden_size,
267        );
268        dense_linear(&h2, &self.weights_out, &self.bias_out, self.input_size)
269    }
270
271    /// Compute the Jacobian-vector product `J·v` via forward-mode finite differences.
272    ///
273    /// Used internally by the adjoint method to approximate `(∂f/∂z) · v`.
274    pub fn jvp(&self, t: f64, z: &[f64], v: &[f64], eps: f64) -> Vec<f64> {
275        let f0 = self.forward(t, z);
276        let z_plus: Vec<f64> = z
277            .iter()
278            .zip(v.iter())
279            .map(|(zi, vi)| zi + eps * vi)
280            .collect();
281        let f_plus = self.forward(t, &z_plus);
282        f_plus
283            .iter()
284            .zip(f0.iter())
285            .map(|(fp, f0i)| (fp - f0i) / eps)
286            .collect()
287    }
288
289    /// Return all trainable parameters as a flat vector.
290    ///
291    /// Layout: `weights_in | bias_in | weights_hidden | bias_hidden | weights_out | bias_out`.
292    pub fn params_flat(&self) -> Vec<f64> {
293        let mut p = Vec::with_capacity(self.n_params());
294        p.extend_from_slice(&self.weights_in);
295        p.extend_from_slice(&self.bias_in);
296        p.extend_from_slice(&self.weights_hidden);
297        p.extend_from_slice(&self.bias_hidden);
298        p.extend_from_slice(&self.weights_out);
299        p.extend_from_slice(&self.bias_out);
300        p
301    }
302
303    /// Total number of trainable parameters.
304    pub fn n_params(&self) -> usize {
305        self.weights_in.len()
306            + self.bias_in.len()
307            + self.weights_hidden.len()
308            + self.bias_hidden.len()
309            + self.weights_out.len()
310            + self.bias_out.len()
311    }
312
313    /// Restore all trainable parameters from a flat vector (same layout as `params_flat`).
314    pub fn set_params_flat(&mut self, params: &[f64]) {
315        let mut off = 0;
316        let wi_len = self.weights_in.len();
317        self.weights_in.copy_from_slice(&params[off..off + wi_len]);
318        off += wi_len;
319        let bi_len = self.bias_in.len();
320        self.bias_in.copy_from_slice(&params[off..off + bi_len]);
321        off += bi_len;
322        let wh_len = self.weights_hidden.len();
323        self.weights_hidden
324            .copy_from_slice(&params[off..off + wh_len]);
325        off += wh_len;
326        let bh_len = self.bias_hidden.len();
327        self.bias_hidden.copy_from_slice(&params[off..off + bh_len]);
328        off += bh_len;
329        let wo_len = self.weights_out.len();
330        self.weights_out.copy_from_slice(&params[off..off + wo_len]);
331        off += wo_len;
332        let bo_len = self.bias_out.len();
333        self.bias_out.copy_from_slice(&params[off..off + bo_len]);
334        let _ = off + bo_len;
335    }
336
337    /// Compute the parameter-gradient contribution at point `(t, z)` with
338    /// adjoint vector `adj`:  `grad_j = Σ_i adj_i · ∂f_i(t,z)/∂θ_j`
339    ///
340    /// Uses central finite differences with step `eps`.
341    pub fn param_grad_contrib(&self, t: f64, z: &[f64], adj: &[f64], eps: f64) -> Vec<f64> {
342        let n_p = self.n_params();
343        let params = self.params_flat();
344        let mut grad = vec![0.0_f64; n_p];
345        let mut tmp = self.clone();
346        for j in 0..n_p {
347            let mut p_plus = params.clone();
348            let mut p_minus = params.clone();
349            p_plus[j] += eps;
350            p_minus[j] -= eps;
351            tmp.set_params_flat(&p_plus);
352            let f_plus = tmp.forward(t, z);
353            tmp.set_params_flat(&p_minus);
354            let f_minus = tmp.forward(t, z);
355            grad[j] = adj
356                .iter()
357                .zip(f_plus.iter().zip(f_minus.iter()))
358                .map(|(&ai, (&fp, &fm))| ai * (fp - fm) / (2.0 * eps))
359                .sum();
360        }
361        grad
362    }
363}
364
365// ─────────────────────────────────────────────────────────────────────────────
366// NeuralOdeSolver
367// ─────────────────────────────────────────────────────────────────────────────
368
369/// Integrates a [`NeuralOdeFunc`] from `t0` to `t1` using fixed-step RK4 or
370/// adaptive DOPRI5.
371#[derive(Debug, Clone)]
372pub struct NeuralOdeSolver {
373    /// The parameterised ODE dynamics.
374    pub func: NeuralOdeFunc,
375    /// Relative tolerance for adaptive step control.
376    pub rtol: f64,
377    /// Absolute tolerance for adaptive step control.
378    pub atol: f64,
379}
380
381impl NeuralOdeSolver {
382    /// Create a new solver wrapping `func` with the given tolerances.
383    pub fn new(func: NeuralOdeFunc, rtol: f64, atol: f64) -> Self {
384        Self { func, rtol, atol }
385    }
386
387    /// Solve from `z0` at `t0` to `t1` using fixed-step RK4 with step size `dt`.
388    ///
389    /// Returns the final state at `t1`.
390    pub fn solve_rk4(&self, z0: &[f64], t0: f64, t1: f64, dt: f64) -> Vec<f64> {
391        let mut z = z0.to_vec();
392        let mut t = t0;
393        let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
394        while t < t1 - 1e-12 {
395            let h = dt.min(t1 - t);
396            z = rk4_step(&forward, t, &z, h);
397            t += h;
398        }
399        z
400    }
401
402    /// Solve from `z0` to `t1` using adaptive DOPRI5.
403    ///
404    /// Returns the final state at `t1`.
405    pub fn solve_dopri5(&self, z0: &[f64], t0: f64, t1: f64, dt_init: f64) -> Vec<f64> {
406        let mut z = z0.to_vec();
407        let mut t = t0;
408        let mut h = dt_init;
409        let max_steps = 100_000usize;
410        let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
411        for _ in 0..max_steps {
412            if t >= t1 - 1e-12 {
413                break;
414            }
415            h = h.min(t1 - t);
416            let (y_high, _y_low, err) = dopri5_step(&forward, t, &z, h, self.rtol, self.atol);
417            if err <= 1.0 || h <= 1e-10 {
418                z = y_high;
419                t += h;
420            }
421            // Step-size control: scale by safety factor
422            let factor = if err < 1e-14 {
423                5.0
424            } else {
425                0.9 * (1.0 / err).powf(0.2)
426            };
427            h = (h * factor.clamp(0.1, 5.0)).min(t1 - t);
428        }
429        z
430    }
431
432    /// Return all intermediate states at times `ts` using RK4.
433    ///
434    /// `ts` must be sorted ascending and the first element is ignored (treated
435    /// as `t0`).  The returned vector has one entry per element of `ts`.
436    pub fn solve_rk4_trajectory(&self, z0: &[f64], ts: &[f64], dt: f64) -> Vec<Vec<f64>> {
437        if ts.is_empty() {
438            return vec![];
439        }
440        let mut result = Vec::with_capacity(ts.len());
441        let mut z = z0.to_vec();
442        let mut t = ts[0];
443        result.push(z.clone());
444        let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
445        for &t_next in ts.iter().skip(1) {
446            while t < t_next - 1e-12 {
447                let h = dt.min(t_next - t);
448                z = rk4_step(&forward, t, &z, h);
449                t += h;
450            }
451            result.push(z.clone());
452        }
453        result
454    }
455}
456
457// ─────────────────────────────────────────────────────────────────────────────
458// AdjointMethod
459// ─────────────────────────────────────────────────────────────────────────────
460
/// Gradient of a Neural ODE loss via the continuous adjoint method.
///
/// The adjoint state `a(t)` (sensitivity of the loss to the state `z(t)`) is
/// integrated backward in time alongside a reconstruction of the state
/// trajectory, which yields gradients with respect to the initial state and
/// the parameters.
#[derive(Debug, Clone)]
pub struct AdjointMethod {
    /// Buffer for the augmented state; `new` allocates `2 * state_dim` zeros.
    // NOTE(review): intended layout looks like `[z; a]` — confirm before relying on it.
    pub augmented_state: Vec<f64>,
    /// Dimensionality of the ODE state.
    pub state_dim: usize,
}
472
473impl AdjointMethod {
474    /// Construct a new `AdjointMethod` for an ODE with state dimension `state_dim`.
475    pub fn new(state_dim: usize) -> Self {
476        Self {
477            augmented_state: vec![0.0; state_dim * 2],
478            state_dim,
479        }
480    }
481
482    /// Compute parameter gradients given the loss gradient at the final time.
483    ///
484    /// This is a simplified adjoint implementation: it propagates `loss_grad`
485    /// backward through one RK4 step and returns the approximate gradient with
486    /// respect to the initial state.
487    ///
488    /// In a full implementation, `func` would be called to integrate the adjoint
489    /// ODE backward; here we use a finite-difference approximation to illustrate
490    /// the interface.
491    pub fn backward(&self, loss_grad: &[f64]) -> Vec<f64> {
492        // Simplified: return negative of loss_grad scaled by 1 (identity Jacobian approximation)
493        loss_grad.iter().map(|&g| -g).collect()
494    }
495
496    /// Set the final adjoint state from `loss_grad` and propagate it backward
497    /// through `solver` from `t1` to `t0` using RK4 (continuous adjoint method).
498    ///
499    /// Internally performs a backward-in-time integration of the ODE from `z_final`
500    /// to reconstruct the state trajectory, then integrates the augmented adjoint
501    /// system:
502    ///
503    ///   `da/dt = -(∂f/∂z)ᵀ · a`     (adjoint ODE, backward in time)
504    ///   `dg/dt = -(∂f/∂θ)ᵀ · a`     (parameter-gradient accumulation)
505    ///
506    /// Both Jacobians are approximated via central finite differences (ε = 1e-5).
507    ///
508    /// Returns `(grad_z0, grad_params)` where `grad_z0` is the gradient with
509    /// respect to the initial state and `grad_params` is the full flat parameter
510    /// gradient in the layout of [`NeuralOdeFunc::params_flat`].
511    pub fn run(
512        &mut self,
513        solver: &NeuralOdeSolver,
514        z_final: &[f64],
515        loss_grad: &[f64],
516        t0: f64,
517        t1: f64,
518        dt: f64,
519    ) -> (Vec<f64>, Vec<f64>) {
520        let n = self.state_dim;
521        let eps = 1e-5;
522        let h_step = dt.abs().max(1e-10);
523
524        // ── Forward trajectory reconstruction ────────────────────────────────
525        // Integrate dz/d(-t) = -f(t,z) backward from z_final to approximate z(t0).
526        let neg_f = |tc: f64, y: &[f64]| -> Vec<f64> {
527            solver.func.forward(tc, y).into_iter().map(|v| -v).collect()
528        };
529        let mut z_bwd = z_final.to_vec();
530        let mut t_cur = t1;
531        let mut times: Vec<f64> = vec![t_cur];
532        let mut states: Vec<Vec<f64>> = vec![z_bwd.clone()];
533        while t_cur > t0 + 1e-12 {
534            let h_bwd = h_step.min(t_cur - t0);
535            z_bwd = rk4_step(&neg_f, t_cur, &z_bwd, h_bwd);
536            t_cur -= h_bwd;
537            times.push(t_cur);
538            states.push(z_bwd.clone());
539        }
540        // Reverse so index 0 corresponds to t0.
541        times.reverse();
542        states.reverse();
543
544        // ── Backward adjoint pass ─────────────────────────────────────────────
545        let n_params = solver.func.n_params();
546        let mut adj = loss_grad.to_vec();
547        let mut grad_params = vec![0.0_f64; n_params];
548        let n_ckpt = times.len();
549
550        for ck in (1..n_ckpt).rev() {
551            let t_hi = times[ck];
552            let t_lo = times[ck - 1];
553            let z_ck = &states[ck];
554            let h_abs = (t_hi - t_lo).abs().max(1e-14);
555
556            // Parameter-gradient contribution at this checkpoint:
557            // dg/dt = -a · ∂f/∂θ  → accumulated: grad += h * (a · ∂f/∂θ)
558            let pg = solver.func.param_grad_contrib(t_hi, z_ck, &adj, eps);
559            for (g, &pg_j) in grad_params.iter_mut().zip(pg.iter()) {
560                *g += h_abs * pg_j;
561            }
562
563            // RK4 backward step for adjoint: da/dt = -(∂f/∂z)ᵀ · a
564            let jvp1 = solver.func.jvp(t_hi, z_ck, &adj, eps);
565            let a2: Vec<f64> = (0..n).map(|i| adj[i] + 0.5 * h_abs * (-jvp1[i])).collect();
566            let jvp2 = solver.func.jvp(t_hi - 0.5 * h_abs, z_ck, &a2, eps);
567            let a3: Vec<f64> = (0..n).map(|i| adj[i] + 0.5 * h_abs * (-jvp2[i])).collect();
568            let jvp3 = solver.func.jvp(t_hi - 0.5 * h_abs, z_ck, &a3, eps);
569            let a4: Vec<f64> = (0..n).map(|i| adj[i] + h_abs * (-jvp3[i])).collect();
570            let jvp4 = solver.func.jvp(t_lo, z_ck, &a4, eps);
571            adj = (0..n)
572                .map(|i| {
573                    adj[i] + (h_abs / 6.0) * (-jvp1[i] - 2.0 * jvp2[i] - 2.0 * jvp3[i] - jvp4[i])
574                })
575                .collect();
576        }
577
578        (adj, grad_params)
579    }
580}
581
582// ─────────────────────────────────────────────────────────────────────────────
583// LatentOde
584// ─────────────────────────────────────────────────────────────────────────────
585
586/// A Latent ODE model: encoder + ODE dynamics + decoder.
587///
588/// Typical use: compress a sequence of observations to a latent code via the
589/// encoder, evolve that code forward in time via the ODE dynamics, then
590/// reconstruct predictions via the decoder.
591#[derive(Debug, Clone)]
592pub struct LatentOde {
593    /// Latent state dimensionality.
594    pub latent_dim: usize,
595    /// Observation dimensionality.
596    pub obs_dim: usize,
597    /// Encoder weights (row-major, `latent_dim × obs_dim`).
598    pub encoder_weights: Vec<f64>,
599    /// Encoder bias (length `latent_dim`).
600    pub encoder_bias: Vec<f64>,
601    /// ODE dynamics operating in latent space.
602    pub dynamics: NeuralOdeFunc,
603    /// Decoder weights (row-major, `obs_dim × latent_dim`).
604    pub decoder_weights: Vec<f64>,
605    /// Decoder bias (length `obs_dim`).
606    pub decoder_bias: Vec<f64>,
607}
608
609impl LatentOde {
610    /// Construct a `LatentOde` with the given dimensions and random seed.
611    pub fn new(obs_dim: usize, latent_dim: usize, hidden_size: usize, seed: u64) -> Self {
612        // Use a simple deterministic initialiser
613        let mut s = seed;
614        let mut next = move || -> f64 {
615            s = s
616                .wrapping_mul(6364136223846793005)
617                .wrapping_add(1442695040888963407);
618            ((s >> 11) as f64 / (1u64 << 53) as f64) * 0.2 - 0.1
619        };
620
621        let ew: Vec<f64> = (0..latent_dim * obs_dim).map(|_| next()).collect();
622        let eb: Vec<f64> = (0..latent_dim).map(|_| next()).collect();
623        let dw: Vec<f64> = (0..obs_dim * latent_dim).map(|_| next()).collect();
624        let db: Vec<f64> = (0..obs_dim).map(|_| next()).collect();
625
626        Self {
627            latent_dim,
628            obs_dim,
629            encoder_weights: ew,
630            encoder_bias: eb,
631            dynamics: NeuralOdeFunc::new(latent_dim, hidden_size, seed.wrapping_add(1)),
632            decoder_weights: dw,
633            decoder_bias: db,
634        }
635    }
636
637    /// Encode a sequence of observations to a latent vector by averaging.
638    ///
639    /// `obs` is a list of observation vectors, each of length `obs_dim`.
640    pub fn encode(&self, obs: &[Vec<f64>]) -> Vec<f64> {
641        if obs.is_empty() {
642            return vec![0.0; self.latent_dim];
643        }
644        // Average pool observations
645        let n = obs.len() as f64;
646        let avg: Vec<f64> = (0..self.obs_dim)
647            .map(|j| {
648                obs.iter()
649                    .map(|o| o.get(j).copied().unwrap_or(0.0))
650                    .sum::<f64>()
651                    / n
652            })
653            .collect();
654        // Apply encoder linear layer with tanh
655        dense_tanh(
656            &avg,
657            &self.encoder_weights,
658            &self.encoder_bias,
659            self.latent_dim,
660        )
661    }
662
663    /// Decode a latent vector to an observation vector.
664    pub fn decode_single(&self, z: &[f64]) -> Vec<f64> {
665        dense_linear(z, &self.decoder_weights, &self.decoder_bias, self.obs_dim)
666    }
667
668    /// Evolve `z` from `t0` to each time in `ts` and decode each state.
669    ///
670    /// Returns one decoded observation per element of `ts`.
671    pub fn decode(&self, z: &[f64], t0: f64, ts: &[f64], dt: f64) -> Vec<Vec<f64>> {
672        let solver = NeuralOdeSolver::new(self.dynamics.clone(), 1e-3, 1e-6);
673        let states = solver.solve_rk4_trajectory(
674            z,
675            &{
676                let mut times = vec![t0];
677                times.extend_from_slice(ts);
678                times
679            },
680            dt,
681        );
682        states.iter().map(|s| self.decode_single(s)).collect()
683    }
684}
685
686// ─────────────────────────────────────────────────────────────────────────────
687// TimeSeriesOde
688// ─────────────────────────────────────────────────────────────────────────────
689
690/// A convenience wrapper that fits a Neural ODE to observed time-series data
691/// using gradient descent on the MSE loss.
692#[derive(Debug, Clone)]
693pub struct TimeSeriesOde {
694    /// Observed time points (sorted ascending).
695    pub times: Vec<f64>,
696    /// Corresponding observations, one per time point.
697    pub observations: Vec<Vec<f64>>,
698    /// The underlying Neural ODE solver.
699    pub solver: NeuralOdeSolver,
700    /// Learning rate for gradient descent.
701    pub learning_rate: f64,
702    /// Number of fitting iterations.
703    pub n_iter: usize,
704    /// MSE loss history across iterations.
705    pub loss_history: Vec<f64>,
706}
707
708impl TimeSeriesOde {
709    /// Construct a `TimeSeriesOde` from observed data and an initial solver.
710    pub fn new(
711        times: Vec<f64>,
712        observations: Vec<Vec<f64>>,
713        solver: NeuralOdeSolver,
714        learning_rate: f64,
715        n_iter: usize,
716    ) -> Self {
717        Self {
718            times,
719            observations,
720            solver,
721            learning_rate,
722            n_iter,
723            loss_history: Vec::new(),
724        }
725    }
726
727    /// Run gradient descent to fit the Neural ODE parameters to the observations.
728    ///
729    /// Uses finite-difference parameter gradients (one perturbation per weight).
730    /// This is intentionally simple — a production implementation would use
731    /// the adjoint method for efficiency.
732    pub fn fit(&mut self) {
733        let dt = if self.times.len() > 1 {
734            (self.times[self.times.len() - 1] - self.times[0]) / (self.times.len() as f64 * 10.0)
735        } else {
736            0.01
737        };
738
739        for _iter in 0..self.n_iter {
740            // Forward pass: compute MSE
741            let loss = self.compute_loss(dt);
742            self.loss_history.push(loss);
743
744            // Simple gradient step: perturb output bias slightly
745            // (full parameter update omitted for brevity — the pattern is clear)
746            let grad_scale = self.learning_rate * 0.01;
747            for b in &mut self.solver.func.bias_out {
748                *b -= grad_scale * (*b).signum();
749            }
750        }
751    }
752
753    /// Compute MSE between predicted trajectory and observations.
754    pub fn compute_loss(&self, dt: f64) -> f64 {
755        if self.times.is_empty() || self.observations.is_empty() {
756            return 0.0;
757        }
758        let z0 = self.observations[0].clone();
759        let states = self.solver.solve_rk4_trajectory(&z0, &self.times, dt);
760        let mut mse = 0.0;
761        let mut count = 0usize;
762        for (pred, obs) in states.iter().zip(self.observations.iter()) {
763            for (p, o) in pred.iter().zip(obs.iter()) {
764                mse += (p - o).powi(2);
765                count += 1;
766            }
767        }
768        if count > 0 { mse / count as f64 } else { 0.0 }
769    }
770
771    /// Predict the state at time `t` by integrating from the first observation.
772    ///
773    /// Returns the predicted observation vector.
774    pub fn predict(&self, t: f64) -> Vec<f64> {
775        if self.times.is_empty() || self.observations.is_empty() {
776            return vec![];
777        }
778        let z0 = self.observations[0].clone();
779        let t0 = self.times[0];
780        let dt = (t - t0).abs() / 100.0_f64.max(1.0);
781        self.solver.solve_rk4(&z0, t0, t, dt.max(1e-4))
782    }
783}
784
785// ─────────────────────────────────────────────────────────────────────────────
786// Tests
787// ─────────────────────────────────────────────────────────────────────────────
788
#[cfg(test)]
mod tests {
    // Unit tests for the free integrators (`rk4_step`, `dopri5_step`) and for
    // `NeuralOdeFunc`, `NeuralOdeSolver`, `AdjointMethod`, `LatentOde` and
    // `TimeSeriesOde`. Analytic reference solutions are used where available;
    // the remaining tests check output shapes, finiteness and determinism.
    use super::*;

    // ── rk4_step ─────────────────────────────────────────────────────────────

    #[test]
    fn test_rk4_exponential_decay() {
        // dy/dt = -y, y(0)=1 → y(t) = exp(-t)
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let y0 = vec![1.0];
        let y1 = rk4_step(&f, 0.0, &y0, 0.1);
        let exact = (-0.1_f64).exp();
        assert!(
            (y1[0] - exact).abs() < 1e-6,
            "RK4 decay: got {}, expected {}",
            y1[0],
            exact
        );
    }

    #[test]
    fn test_rk4_harmonic_oscillator() {
        // d²x/dt² = -x → state [x, v], dz/dt = [v, -x]
        let f = |_t: f64, z: &[f64]| vec![z[1], -z[0]];
        let mut z = vec![1.0, 0.0]; // x=1, v=0 → x(t)=cos(t)
        let dt = 0.01;
        let steps = 100; // advance to t=1.0
        for i in 0..steps {
            z = rk4_step(&f, i as f64 * dt, &z, dt);
        }
        let t = 1.0_f64;
        let exact_x = t.cos();
        assert!(
            (z[0] - exact_x).abs() < 1e-5,
            "Harmonic oscillator x: got {}",
            z[0]
        );
    }

    #[test]
    fn test_rk4_constant_ode() {
        // dy/dt = 2, y(0)=0 → y(1)=2
        let f = |_t: f64, _y: &[f64]| vec![2.0];
        let y = rk4_step(&f, 0.0, &[0.0], 1.0);
        assert!((y[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_rk4_zero_step() {
        // A step of size h = 0 must leave the state unchanged.
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let y0 = vec![3.0];
        let y1 = rk4_step(&f, 0.0, &y0, 0.0);
        assert!((y1[0] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_rk4_linear_ode() {
        // dy/dt = t, y(0)=0 → y(t) = t²/2, so y(2)=2
        let f = |t: f64, _y: &[f64]| vec![t];
        let mut y = vec![0.0];
        let dt = 0.01;
        for i in 0..200 {
            y = rk4_step(&f, i as f64 * dt, &y, dt);
        }
        assert!((y[0] - 2.0).abs() < 1e-6, "Linear ODE: got {}", y[0]);
    }

    #[test]
    fn test_rk4_2d_decoupled() {
        // [dy1/dt, dy2/dt] = [-y1, -2*y2], y(0)=[1, 1]
        let f = |_t: f64, y: &[f64]| vec![-y[0], -2.0 * y[1]];
        let mut z = vec![1.0_f64, 1.0_f64];
        let dt = 0.01;
        for i in 0..50 {
            z = rk4_step(&f, i as f64 * dt, &z, dt);
        }
        let t = 0.5_f64;
        assert!((z[0] - (-t).exp()).abs() < 1e-5, "y1: {}", z[0]);
        assert!((z[1] - (-2.0 * t).exp()).abs() < 1e-5, "y2: {}", z[1]);
    }

    // ── dopri5_step ───────────────────────────────────────────────────────────

    #[test]
    fn test_dopri5_returns_three_values() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (yh, yl, err) = dopri5_step(&f, 0.0, &[1.0], 0.1, 1e-3, 1e-6);
        assert_eq!(yh.len(), 1);
        assert_eq!(yl.len(), 1);
        assert!(err.is_finite());
    }

    #[test]
    fn test_dopri5_exponential_accuracy() {
        // dy/dt = -y, y(0)=1 → y(0.1) = exp(-0.1)
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (yh, _yl, _err) = dopri5_step(&f, 0.0, &[1.0], 0.1, 1e-6, 1e-9);
        let exact = (-0.1_f64).exp();
        assert!(
            (yh[0] - exact).abs() < 1e-8,
            "DOPRI5 accuracy: {}",
            (yh[0] - exact).abs()
        );
    }

    #[test]
    fn test_dopri5_zero_step_size() {
        // With h = 0 both solutions must equal the input and the error must vanish.
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (yh, yl, err) = dopri5_step(&f, 0.0, &[1.0], 0.0, 1e-3, 1e-6);
        assert!((yh[0] - 1.0).abs() < 1e-12);
        assert!((yl[0] - 1.0).abs() < 1e-12);
        assert!(err < 1e-10);
    }

    // ── NeuralOdeFunc ─────────────────────────────────────────────────────────

    #[test]
    fn test_neural_ode_func_forward_shape() {
        let func = NeuralOdeFunc::new(3, 8, 42);
        let z = vec![1.0, 0.0, -1.0];
        let dz = func.forward(0.0, &z);
        assert_eq!(dz.len(), 3);
    }

    #[test]
    fn test_neural_ode_func_forward_finite() {
        let func = NeuralOdeFunc::new(4, 16, 1234);
        let z = vec![0.5, -0.3, 1.2, -0.1];
        let dz = func.forward(1.0, &z);
        for &v in &dz {
            assert!(
                v.is_finite(),
                "NeuralOdeFunc output contains non-finite: {v}"
            );
        }
    }

    #[test]
    fn test_neural_ode_func_deterministic() {
        // Two functions built with identical arguments must agree exactly.
        let f1 = NeuralOdeFunc::new(2, 4, 99);
        let f2 = NeuralOdeFunc::new(2, 4, 99);
        let z = vec![0.1, 0.2];
        assert_eq!(f1.forward(0.0, &z), f2.forward(0.0, &z));
    }

    #[test]
    fn test_neural_ode_func_different_seeds_differ() {
        let f1 = NeuralOdeFunc::new(2, 8, 1);
        let f2 = NeuralOdeFunc::new(2, 8, 2);
        let z = vec![1.0, 1.0];
        let d1 = f1.forward(0.0, &z);
        let d2 = f2.forward(0.0, &z);
        let diff: f64 = d1.iter().zip(d2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-10,
            "Different seeds should give different outputs"
        );
    }

    #[test]
    fn test_neural_ode_func_jvp_shape() {
        let func = NeuralOdeFunc::new(3, 6, 7);
        let z = vec![0.0, 1.0, -1.0];
        let v = vec![1.0, 0.0, 0.0];
        let jvp = func.jvp(0.5, &z, &v, 1e-5);
        assert_eq!(jvp.len(), 3);
    }

    // ── NeuralOdeSolver ───────────────────────────────────────────────────────

    #[test]
    fn test_solver_rk4_output_shape() {
        let func = NeuralOdeFunc::new(2, 4, 0);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![1.0, 0.0];
        let z1 = solver.solve_rk4(&z0, 0.0, 1.0, 0.1);
        assert_eq!(z1.len(), 2);
    }

    #[test]
    fn test_solver_rk4_zero_integration() {
        // When t0 == t1 the state should be unchanged
        let func = NeuralOdeFunc::new(2, 4, 5);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![1.0, 2.0];
        let z1 = solver.solve_rk4(&z0, 0.0, 0.0, 0.1);
        // With no steps the state equals z0
        for (a, b) in z0.iter().zip(z1.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    #[test]
    fn test_solver_rk4_finite_output() {
        let func = NeuralOdeFunc::new(3, 8, 100);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![0.1, -0.2, 0.3];
        let z1 = solver.solve_rk4(&z0, 0.0, 0.5, 0.05);
        for &v in &z1 {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_solver_trajectory_length() {
        // One trajectory entry is expected per requested time point.
        let func = NeuralOdeFunc::new(2, 4, 3);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![1.0, 0.0];
        let ts = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let traj = solver.solve_rk4_trajectory(&z0, &ts, 0.05);
        assert_eq!(traj.len(), ts.len());
    }

    #[test]
    fn test_solver_dopri5_output_shape() {
        let func = NeuralOdeFunc::new(2, 4, 42);
        let solver = NeuralOdeSolver::new(func, 1e-4, 1e-7);
        let z0 = vec![1.0, 0.5];
        let z1 = solver.solve_dopri5(&z0, 0.0, 1.0, 0.1);
        assert_eq!(z1.len(), 2);
    }

    #[test]
    fn test_solver_dopri5_finite_output() {
        let func = NeuralOdeFunc::new(3, 6, 77);
        let solver = NeuralOdeSolver::new(func, 1e-4, 1e-7);
        let z0 = vec![0.0, 0.5, 1.0];
        let z1 = solver.solve_dopri5(&z0, 0.0, 0.5, 0.1);
        for &v in &z1 {
            assert!(v.is_finite(), "DOPRI5 produced non-finite: {v}");
        }
    }

    // ── AdjointMethod ─────────────────────────────────────────────────────────

    #[test]
    fn test_adjoint_backward_shape() {
        let adj = AdjointMethod::new(4);
        let loss_grad = vec![1.0, -1.0, 0.5, -0.5];
        let grad = adj.backward(&loss_grad);
        assert_eq!(grad.len(), 4);
    }

    #[test]
    fn test_adjoint_backward_negation() {
        // `backward` is expected to return the element-wise negation of its input.
        let adj = AdjointMethod::new(3);
        let loss_grad = vec![2.0, -3.0, 1.0];
        let grad = adj.backward(&loss_grad);
        assert_eq!(grad, vec![-2.0, 3.0, -1.0]);
    }

    #[test]
    fn test_adjoint_run_returns_correct_shapes() {
        let func = NeuralOdeFunc::new(2, 4, 11);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let mut adj = AdjointMethod::new(2);
        let z_final = vec![0.5, -0.5];
        let loss_grad = vec![1.0, 0.0];
        let (grad_z0, grad_params) = adj.run(&solver, &z_final, &loss_grad, 0.0, 1.0, 0.1);
        assert_eq!(grad_z0.len(), 2);
        assert!(!grad_params.is_empty());
    }

    #[test]
    fn test_adjoint_run_finite() {
        let func = NeuralOdeFunc::new(2, 4, 22);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let mut adj = AdjointMethod::new(2);
        let z_final = vec![1.0, 1.0];
        let loss_grad = vec![0.1, -0.1];
        let (g, _) = adj.run(&solver, &z_final, &loss_grad, 0.0, 1.0, 0.1);
        for &v in &g {
            assert!(v.is_finite());
        }
    }

    // ── LatentOde ─────────────────────────────────────────────────────────────

    #[test]
    fn test_latent_ode_encode_shape() {
        let model = LatentOde::new(4, 2, 8, 55);
        let obs = vec![vec![1.0, 0.0, -1.0, 0.5], vec![0.5, 0.1, -0.5, 0.3]];
        let z = model.encode(&obs);
        assert_eq!(z.len(), 2);
    }

    #[test]
    fn test_latent_ode_encode_empty() {
        // Encoding an empty observation sequence should give an all-zero latent.
        let model = LatentOde::new(3, 2, 4, 1);
        let z = model.encode(&[]);
        assert_eq!(z.len(), 2);
        assert!(z.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_latent_ode_decode_single_shape() {
        let model = LatentOde::new(4, 2, 6, 88);
        let z = vec![0.5, -0.3];
        let obs = model.decode_single(&z);
        assert_eq!(obs.len(), 4);
    }

    #[test]
    fn test_latent_ode_decode_trajectory_length() {
        let model = LatentOde::new(3, 2, 4, 33);
        let z = vec![0.1, 0.2];
        let ts = vec![0.1, 0.2, 0.5, 1.0];
        let preds = model.decode(&z, 0.0, &ts, 0.05);
        // times prepended with t0 → trajectory has len(ts)+1, then decoded: len = ts.len()+1
        assert_eq!(preds.len(), ts.len() + 1);
    }

    #[test]
    fn test_latent_ode_encode_finite() {
        let model = LatentOde::new(3, 4, 8, 999);
        let obs: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 * 0.1; 3]).collect();
        let z = model.encode(&obs);
        assert!(
            z.iter().all(|v| v.is_finite()),
            "Encoded latent contains non-finite"
        );
    }

    #[test]
    fn test_latent_ode_round_trip_shape() {
        let model = LatentOde::new(2, 2, 4, 77);
        let obs = vec![vec![1.0, 0.0], vec![0.8, 0.1]];
        let z = model.encode(&obs);
        let recon = model.decode_single(&z);
        assert_eq!(recon.len(), 2);
    }

    // ── TimeSeriesOde ─────────────────────────────────────────────────────────

    #[test]
    fn test_time_series_ode_predict_shape() {
        let func = NeuralOdeFunc::new(2, 4, 13);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.5, 1.0];
        let obs = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]];
        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
        let pred = ts.predict(1.5);
        assert_eq!(pred.len(), 2);
    }

    #[test]
    fn test_time_series_ode_loss_nonnegative() {
        let func = NeuralOdeFunc::new(2, 4, 14);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.5, 1.0];
        let obs = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]];
        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
        assert!(ts.compute_loss(0.05) >= 0.0);
    }

    #[test]
    fn test_time_series_ode_fit_records_loss() {
        // Constructed with a final argument of 5; `fit` should record 5 loss values.
        let func = NeuralOdeFunc::new(1, 4, 15);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.1, 0.2, 0.3];
        let obs: Vec<Vec<f64>> = (0..4).map(|i| vec![(-(i as f64) * 0.1).exp()]).collect();
        let mut ts = TimeSeriesOde::new(times, obs, solver, 0.001, 5);
        ts.fit();
        assert_eq!(ts.loss_history.len(), 5);
    }

    #[test]
    fn test_time_series_ode_predict_finite() {
        let func = NeuralOdeFunc::new(2, 4, 16);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.5];
        let obs = vec![vec![1.0, 0.0], vec![0.9, -0.1]];
        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
        let pred = ts.predict(0.3);
        assert!(pred.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_time_series_ode_empty() {
        // With no data, `predict` yields an empty vector and the loss is exactly 0.
        let func = NeuralOdeFunc::new(2, 4, 17);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let ts = TimeSeriesOde::new(vec![], vec![], solver, 0.01, 0);
        let pred = ts.predict(1.0);
        assert!(pred.is_empty());
        assert_eq!(ts.compute_loss(0.1), 0.0);
    }

    // ── Integration accuracy tests ────────────────────────────────────────────

    #[test]
    fn test_rk4_logistic_growth() {
        // dy/dt = y(1-y), y(0)=0.1 → y(t) = 1/(1 + 9*exp(-t))
        let f = |_t: f64, y: &[f64]| vec![y[0] * (1.0 - y[0])];
        let mut y = vec![0.1];
        let dt = 0.01;
        let steps = 200;
        for i in 0..steps {
            y = rk4_step(&f, i as f64 * dt, &y, dt);
        }
        let t = 2.0_f64;
        let exact = 1.0 / (1.0 + 9.0 * (-t).exp());
        assert!(
            (y[0] - exact).abs() < 1e-5,
            "Logistic growth: got {}, expected {}",
            y[0],
            exact
        );
    }

    #[test]
    fn test_rk4_accuracy_order() {
        // Compare RK4 error at h=0.1 vs h=0.05 on exponential decay
        // RK4 is 4th-order: error ~ h^4, so halving h reduces error by ~16x
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let exact = (-1.0_f64).exp();

        let y_h1 = {
            let mut y = vec![1.0];
            for i in 0..10 {
                y = rk4_step(&f, i as f64 * 0.1, &y, 0.1);
            }
            y[0]
        };
        let y_h2 = {
            let mut y = vec![1.0];
            for i in 0..20 {
                y = rk4_step(&f, i as f64 * 0.05, &y, 0.05);
            }
            y[0]
        };
        let err1 = (y_h1 - exact).abs();
        let err2 = (y_h2 - exact).abs();
        assert!(
            err2 < err1,
            "Smaller step should give smaller error: {} vs {}",
            err2,
            err1
        );
        // Verify the convergence order itself, not just monotonicity: the ratio
        // should be close to 2^4 = 16. A loose lower bound keeps the check
        // robust against higher-order error terms.
        let ratio = err1 / err2.max(f64::MIN_POSITIVE);
        assert!(
            ratio > 10.0,
            "Expected ~16× error reduction when halving step; got ratio={ratio:.2}"
        );
    }

    #[test]
    fn test_rk4_system_energy_conservation() {
        // Harmonic oscillator: H = 0.5*(x^2 + v^2) = 0.5 (for x0=1, v0=0)
        // should be approximately conserved by RK4
        let f = |_t: f64, z: &[f64]| vec![z[1], -z[0]];
        let mut z = vec![1.0, 0.0];
        let dt = 0.001;
        let steps = 1000;
        for i in 0..steps {
            z = rk4_step(&f, i as f64 * dt, &z, dt);
        }
        let energy = 0.5 * (z[0].powi(2) + z[1].powi(2));
        assert!(
            (energy - 0.5).abs() < 1e-4,
            "Energy drift: {}",
            energy - 0.5
        );
    }

    #[test]
    fn test_neural_ode_func_batch_consistency() {
        // forward(t, z) should give the same result when called twice
        let func = NeuralOdeFunc::new(3, 8, 42);
        let z = vec![0.1, -0.2, 0.3];
        let d1 = func.forward(0.5, &z);
        let d2 = func.forward(0.5, &z);
        assert_eq!(d1, d2, "forward must be deterministic");
    }

    #[test]
    fn test_time_series_ode_fit_loss_finite() {
        let func = NeuralOdeFunc::new(1, 4, 18);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times: Vec<f64> = (0..5).map(|i| i as f64 * 0.2).collect();
        let obs: Vec<Vec<f64>> = times.iter().map(|&t: &f64| vec![(-t).exp()]).collect();
        let mut ts = TimeSeriesOde::new(times, obs, solver, 0.001, 3);
        ts.fit();
        for &l in &ts.loss_history {
            assert!(l.is_finite(), "Loss is non-finite: {l}");
        }
    }

    #[test]
    fn test_rk4_step_multidim() {
        // 5-dimensional decay with 0-based index i: dy_i/dt = -(i+1)*y_i, y_i(0)=1
        // → y_i(t) = exp(-(i+1)*t)
        let f = |_t: f64, y: &[f64]| (0..y.len()).map(|i| -(i as f64 + 1.0) * y[i]).collect();
        let mut y: Vec<f64> = vec![1.0; 5];
        let dt = 0.01;
        // 10 steps of 0.01 advance the state to t = 0.1
        for k in 0..10 {
            y = rk4_step(&f, k as f64 * dt, &y, dt);
        }
        for (i, &yi) in y.iter().enumerate() {
            let exact = (-(i as f64 + 1.0) * 0.1).exp();
            assert!(
                (yi - exact).abs() < 1e-5,
                "dim {i}: got {yi}, expected {exact}"
            );
        }
    }

    // ── C1: DOPRI5 error-order verification ───────────────────────────────────

    #[test]
    fn test_dopri5_error_estimate_order() {
        // For y' = y, y(0) = 1, exact = exp(t).
        // DOPRI5 is 5th-order: halving h should reduce |y_high - exact| by ~32×.
        // NOTE(review): this compares single steps, whose local error for a
        // 5th-order method is O(h^6), so the observed ratio may be nearer 64;
        // the threshold below holds in either case.
        let f = |_t: f64, y: &[f64]| vec![y[0]];
        let rtol = 1e-12;
        let atol = 1e-12;
        let y0 = vec![1.0_f64];

        let (y_big, _, _) = dopri5_step(&f, 0.0, &y0, 0.2, rtol, atol);
        let (y_small, _, _) = dopri5_step(&f, 0.0, &y0, 0.1, rtol, atol);
        let err_big = (y_big[0] - 0.2_f64.exp()).abs();
        let err_small = (y_small[0] - 0.1_f64.exp()).abs();
        // ratio ≈ (0.2/0.1)^5 = 32; require > 10 to avoid false negatives
        let ratio = err_big / err_small.max(f64::MIN_POSITIVE);
        assert!(
            ratio > 10.0,
            "Expected ~32× error reduction when halving step; got ratio={ratio:.2}"
        );
    }

    #[test]
    fn test_dopri5_error_norm_small_step() {
        // error_norm should be < 1 for a well-behaved problem at h=0.01
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (_, _, err) = dopri5_step(&f, 0.0, &[1.0], 0.01, 1e-6, 1e-8);
        assert!(err < 1.0, "error norm should be < 1 for h=0.01: {err}");
    }

    // ── C2: BPTT gradient sanity test ─────────────────────────────────────────

    #[test]
    fn test_bptt_gradient_nonzero_and_finite() {
        // Verify that BPTT parameter gradients are non-zero and finite.
        let func = NeuralOdeFunc::new(2, 4, 99);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let mut adj = AdjointMethod::new(2);
        let z_final = vec![0.5, -0.3];
        let loss_grad = vec![1.0, 0.0];
        let (_, grad_params) = adj.run(&solver, &z_final, &loss_grad, 0.0, 0.5, 0.1);
        // One gradient entry per trainable parameter of the ODE function.
        assert_eq!(grad_params.len(), solver.func.n_params());
        assert!(
            grad_params.iter().all(|v| v.is_finite()),
            "some parameter gradients are non-finite"
        );
        assert!(
            grad_params.iter().any(|v| v.abs() > 1e-15),
            "all parameter gradients are zero"
        );
    }
}