Skip to main content

oxiphysics_core/
neural_ode.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Neural ODE (continuous-depth neural networks) implementations.
5//!
6//! Implements Neural ODEs as introduced by Chen et al. (NeurIPS 2018).
7//! Provides:
8//! - [`NeuralOdeFunc`]: parameterised ODE right-hand side (small MLP)
9//! - [`NeuralOdeSolver`]: wraps a [`NeuralOdeFunc`] with RK4 integration
10//! - [`AdjointMethod`]: reverse-mode gradient estimation via the adjoint
11//! - [`LatentOde`]: encoder-dynamics-decoder architecture
12//! - [`TimeSeriesOde`]: convenience wrapper for time-series fitting
13//! - Free functions [`rk4_step`] and [`dopri5_step`]
14
15#![allow(dead_code)]
16#![allow(clippy::too_many_arguments)]
17
18// ─────────────────────────────────────────────────────────────────────────────
19// Free integration helpers
20// ─────────────────────────────────────────────────────────────────────────────
21
/// Advance an ODE state by one classical fourth-order Runge-Kutta step.
///
/// # Arguments
/// * `f` – Right-hand side of the ODE, invoked as `f(t, y)`.
/// * `t` – Time at the beginning of the step.
/// * `y` – State at time `t`.
/// * `h` – Step size.
///
/// # Returns
/// The state advanced to `t + h`; it has the same length as `y`.
///
/// # Panics
/// Panics if `f` returns a vector shorter than `y`.
pub fn rk4_step(f: &dyn Fn(f64, &[f64]) -> Vec<f64>, t: f64, y: &[f64], h: f64) -> Vec<f64> {
    let dim = y.len();
    let half_step = 0.5 * h;

    // Trial state `y + scale * slope` handed to the next stage.
    let offset = |scale: f64, slope: &[f64]| -> Vec<f64> {
        (0..dim).map(|idx| y[idx] + scale * slope[idx]).collect()
    };

    let k1 = f(t, y);
    let k2 = f(t + half_step, &offset(half_step, &k1));
    let k3 = f(t + half_step, &offset(half_step, &k2));
    let k4 = f(t + h, &offset(h, &k3));

    // Simpson-like weighting 1-2-2-1 of the four slopes.
    let weight = h / 6.0;
    let mut advanced = Vec::with_capacity(dim);
    for idx in 0..dim {
        let blended = k1[idx] + 2.0 * k2[idx] + 2.0 * k3[idx] + k4[idx];
        advanced.push(y[idx] + weight * blended);
    }
    advanced
}
45
/// Perform a single Dormand-Prince 5(4) step with an embedded error estimate.
///
/// Returns `(y_high, y_low, error_norm)` where `y_high` is the 5th-order
/// solution, `y_low` is the embedded 4th-order solution, and `error_norm` is
/// the RMS of `(y_high - y_low) / (atol + rtol * max(|y|, |y_high|))`.
/// A value `<= 1` means the step satisfies the requested tolerances.
///
/// The embedded solution needs the seventh stage `k7 = f(t + h, y_high)`, so
/// the right-hand side is evaluated seven times per call.
///
/// # Arguments
/// * `f`   – ODE right-hand side `f(t, y)`.
/// * `t`   – Current time.
/// * `y`   – Current state.
/// * `h`   – Proposed step size.
/// * `rtol` – Relative tolerance (used in the error norm).
/// * `atol` – Absolute tolerance (used in the error norm).
///
/// # Returns
/// For an empty state the function returns two empty vectors and an error
/// norm of `0.0` without calling `f`.
///
/// # Panics
/// Panics if `f` returns a vector shorter than `y`.
pub fn dopri5_step(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    t: f64,
    y: &[f64],
    h: f64,
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    // `y + h · Σ coeff · slope`, evaluated component-wise.
    fn combine(y: &[f64], h: f64, terms: &[(f64, &[f64])]) -> Vec<f64> {
        (0..y.len())
            .map(|i| {
                let slope_sum: f64 = terms.iter().map(|&(coeff, slope)| coeff * slope[i]).sum();
                y[i] + h * slope_sum
            })
            .collect()
    }

    let n = y.len();
    if n == 0 {
        // Avoid the 0/0 in the RMS norm below; nothing to integrate.
        return (Vec::new(), Vec::new(), 0.0);
    }

    // Butcher tableau nodes (Dormand-Prince)
    let c2 = 1.0 / 5.0;
    let c3 = 3.0 / 10.0;
    let c4 = 4.0 / 5.0;
    let c5 = 8.0 / 9.0;

    let k1 = f(t, y);

    let y2 = combine(y, h, &[(1.0 / 5.0, &k1[..])]);
    let k2 = f(t + c2 * h, &y2);

    let y3 = combine(y, h, &[(3.0 / 40.0, &k1[..]), (9.0 / 40.0, &k2[..])]);
    let k3 = f(t + c3 * h, &y3);

    let y4 = combine(
        y,
        h,
        &[
            (44.0 / 45.0, &k1[..]),
            (-56.0 / 15.0, &k2[..]),
            (32.0 / 9.0, &k3[..]),
        ],
    );
    let k4 = f(t + c4 * h, &y4);

    let y5 = combine(
        y,
        h,
        &[
            (19372.0 / 6561.0, &k1[..]),
            (-25360.0 / 2187.0, &k2[..]),
            (64448.0 / 6561.0, &k3[..]),
            (-212.0 / 729.0, &k4[..]),
        ],
    );
    let k5 = f(t + c5 * h, &y5);

    let y6 = combine(
        y,
        h,
        &[
            (9017.0 / 3168.0, &k1[..]),
            (-355.0 / 33.0, &k2[..]),
            (46732.0 / 5247.0, &k3[..]),
            (49.0 / 176.0, &k4[..]),
            (-5103.0 / 18656.0, &k5[..]),
        ],
    );
    let k6 = f(t + h, &y6);

    // 5th-order solution
    let y_high = combine(
        y,
        h,
        &[
            (35.0 / 384.0, &k1[..]),
            (500.0 / 1113.0, &k3[..]),
            (125.0 / 192.0, &k4[..]),
            (-2187.0 / 6784.0, &k5[..]),
            (11.0 / 84.0, &k6[..]),
        ],
    );

    // Seventh stage: the slope at the 5th-order solution.  The embedded
    // 4th-order weights include a 1/40 contribution from this stage; using
    // any other slope here degrades the error estimate from O(h^5) to O(h^2).
    let k7 = f(t + h, &y_high);

    // 4th-order solution (for error estimate)
    let y_low = combine(
        y,
        h,
        &[
            (5179.0 / 57600.0, &k1[..]),
            (7571.0 / 16695.0, &k3[..]),
            (393.0 / 640.0, &k4[..]),
            (-92097.0 / 339200.0, &k5[..]),
            (187.0 / 2100.0, &k6[..]),
            (1.0 / 40.0, &k7[..]),
        ],
    );

    // Error norm (RMS with mixed tolerance)
    let err_sq: f64 = (0..n)
        .map(|i| {
            let sc = atol + rtol * y[i].abs().max(y_high[i].abs());
            let e = y_high[i] - y_low[i];
            (e / sc).powi(2)
        })
        .sum::<f64>()
        / n as f64;
    let error_norm = err_sq.sqrt();

    (y_high, y_low, error_norm)
}
146
147// ─────────────────────────────────────────────────────────────────────────────
148// Activation helpers
149// ─────────────────────────────────────────────────────────────────────────────
150
/// Apply the hyperbolic tangent to every entry of `v`, returning a new vector.
fn tanh_vec(v: &[f64]) -> Vec<f64> {
    let mut squashed = Vec::with_capacity(v.len());
    for &value in v {
        squashed.push(value.tanh());
    }
    squashed
}
155
/// Fully connected layer followed by `tanh`:
/// `output[i] = tanh(Σ_j w[i * input.len() + j] · input[j] + b[i])`.
///
/// `w` is the `out × input.len()` weight matrix in row-major order and `b`
/// holds one bias per output unit.
///
/// # Panics
/// Panics if `w` or `b` is too short for the requested `out` rows.
fn dense_tanh(input: &[f64], w: &[f64], b: &[f64], out: usize) -> Vec<f64> {
    let width = input.len();
    let mut activations = Vec::with_capacity(out);
    for row in 0..out {
        let row_start = row * width;
        let dot: f64 = input
            .iter()
            .enumerate()
            .map(|(col, x)| w[row_start + col] * x)
            .sum();
        let pre_activation = dot + b[row];
        activations.push(pre_activation.tanh());
    }
    activations
}
168
/// Fully connected layer without activation:
/// `output[i] = Σ_j w[i * input.len() + j] · input[j] + b[i]`.
///
/// `w` is the `out × input.len()` weight matrix in row-major order and `b`
/// holds one bias per output unit.
///
/// # Panics
/// Panics if `w` or `b` is too short for the requested `out` rows.
fn dense_linear(input: &[f64], w: &[f64], b: &[f64], out: usize) -> Vec<f64> {
    let width = input.len();
    let mut outputs = Vec::with_capacity(out);
    for row in 0..out {
        let row_start = row * width;
        let dot: f64 = input
            .iter()
            .enumerate()
            .map(|(col, x)| w[row_start + col] * x)
            .sum();
        outputs.push(dot + b[row]);
    }
    outputs
}
176
177// ─────────────────────────────────────────────────────────────────────────────
178// NeuralOdeFunc
179// ─────────────────────────────────────────────────────────────────────────────
180
181/// The dynamics function of a Neural ODE — a small MLP that maps `(t, z)` to
182/// `dz/dt`.
183///
184/// Architecture: `z → tanh(W_in·z + b_in) → tanh(W_h·h + b_h) → W_out·h2 + b_out`.
185#[derive(Debug, Clone)]
186pub struct NeuralOdeFunc {
187    /// Input dimensionality (size of the state vector).
188    pub input_size: usize,
189    /// Number of hidden units in each hidden layer.
190    pub hidden_size: usize,
191    /// Weight matrix from input to first hidden layer (row-major, `hidden × input`).
192    pub weights_in: Vec<f64>,
193    /// Bias for the first hidden layer (length `hidden_size`).
194    pub bias_in: Vec<f64>,
195    /// Weight matrix from first hidden to second hidden layer (row-major, `hidden × hidden`).
196    pub weights_hidden: Vec<f64>,
197    /// Bias for the second hidden layer (length `hidden_size`).
198    pub bias_hidden: Vec<f64>,
199    /// Weight matrix from second hidden to output (row-major, `input × hidden`).
200    pub weights_out: Vec<f64>,
201    /// Bias for the output layer (length `input_size`).
202    pub bias_out: Vec<f64>,
203}
204
205impl NeuralOdeFunc {
206    /// Construct a `NeuralOdeFunc` with all weights initialised to small random
207    /// values using a simple linear congruential generator seeded by `seed`.
208    pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
209        let mut rng_state = seed;
210        let mut next = move || -> f64 {
211            rng_state = rng_state
212                .wrapping_mul(6364136223846793005)
213                .wrapping_add(1442695040888963407);
214            // Map to [-0.1, 0.1]
215            let bits = (rng_state >> 11) as f64;
216            (bits / (1u64 << 53) as f64) * 0.2 - 0.1
217        };
218
219        // +1 for the time input that is appended in `forward`
220        let wi: Vec<f64> = (0..hidden_size * (input_size + 1))
221            .map(|_| next())
222            .collect();
223        let bi: Vec<f64> = (0..hidden_size).map(|_| next()).collect();
224        let wh: Vec<f64> = (0..hidden_size * hidden_size).map(|_| next()).collect();
225        let bh: Vec<f64> = (0..hidden_size).map(|_| next()).collect();
226        let wo: Vec<f64> = (0..input_size * hidden_size).map(|_| next()).collect();
227        let bo: Vec<f64> = (0..input_size).map(|_| next()).collect();
228
229        Self {
230            input_size,
231            hidden_size,
232            weights_in: wi,
233            bias_in: bi,
234            weights_hidden: wh,
235            bias_hidden: bh,
236            weights_out: wo,
237            bias_out: bo,
238        }
239    }
240
241    /// Evaluate the ODE right-hand side: `dz/dt = f(t, z)`.
242    ///
243    /// The time `t` is concatenated to `z` before the first layer so the
244    /// network can model non-autonomous dynamics.
245    pub fn forward(&self, t: f64, z: &[f64]) -> Vec<f64> {
246        // Augment state with time
247        let mut aug = Vec::with_capacity(self.input_size + 1);
248        aug.extend_from_slice(z);
249        aug.push(t);
250
251        let h1 = dense_tanh(&aug, &self.weights_in, &self.bias_in, self.hidden_size);
252        let h2 = dense_tanh(
253            &h1,
254            &self.weights_hidden,
255            &self.bias_hidden,
256            self.hidden_size,
257        );
258        dense_linear(&h2, &self.weights_out, &self.bias_out, self.input_size)
259    }
260
261    /// Compute the Jacobian-vector product `J·v` via finite differences.
262    ///
263    /// Used internally by the adjoint method.
264    pub fn jvp(&self, t: f64, z: &[f64], v: &[f64], eps: f64) -> Vec<f64> {
265        let f0 = self.forward(t, z);
266        let z_plus: Vec<f64> = z
267            .iter()
268            .zip(v.iter())
269            .map(|(zi, vi)| zi + eps * vi)
270            .collect();
271        let f_plus = self.forward(t, &z_plus);
272        f_plus
273            .iter()
274            .zip(f0.iter())
275            .map(|(fp, f0i)| (fp - f0i) / eps)
276            .collect()
277    }
278}
279
280// ─────────────────────────────────────────────────────────────────────────────
281// NeuralOdeSolver
282// ─────────────────────────────────────────────────────────────────────────────
283
284/// Integrates a [`NeuralOdeFunc`] from `t0` to `t1` using fixed-step RK4 or
285/// adaptive DOPRI5.
286#[derive(Debug, Clone)]
287pub struct NeuralOdeSolver {
288    /// The parameterised ODE dynamics.
289    pub func: NeuralOdeFunc,
290    /// Relative tolerance for adaptive step control.
291    pub rtol: f64,
292    /// Absolute tolerance for adaptive step control.
293    pub atol: f64,
294}
295
296impl NeuralOdeSolver {
297    /// Create a new solver wrapping `func` with the given tolerances.
298    pub fn new(func: NeuralOdeFunc, rtol: f64, atol: f64) -> Self {
299        Self { func, rtol, atol }
300    }
301
302    /// Solve from `z0` at `t0` to `t1` using fixed-step RK4 with step size `dt`.
303    ///
304    /// Returns the final state at `t1`.
305    pub fn solve_rk4(&self, z0: &[f64], t0: f64, t1: f64, dt: f64) -> Vec<f64> {
306        let mut z = z0.to_vec();
307        let mut t = t0;
308        let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
309        while t < t1 - 1e-12 {
310            let h = dt.min(t1 - t);
311            z = rk4_step(&forward, t, &z, h);
312            t += h;
313        }
314        z
315    }
316
317    /// Solve from `z0` to `t1` using adaptive DOPRI5.
318    ///
319    /// Returns the final state at `t1`.
320    pub fn solve_dopri5(&self, z0: &[f64], t0: f64, t1: f64, dt_init: f64) -> Vec<f64> {
321        let mut z = z0.to_vec();
322        let mut t = t0;
323        let mut h = dt_init;
324        let max_steps = 100_000usize;
325        let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
326        for _ in 0..max_steps {
327            if t >= t1 - 1e-12 {
328                break;
329            }
330            h = h.min(t1 - t);
331            let (y_high, _y_low, err) = dopri5_step(&forward, t, &z, h, self.rtol, self.atol);
332            if err <= 1.0 || h <= 1e-10 {
333                z = y_high;
334                t += h;
335            }
336            // Step-size control: scale by safety factor
337            let factor = if err < 1e-14 {
338                5.0
339            } else {
340                0.9 * (1.0 / err).powf(0.2)
341            };
342            h = (h * factor.clamp(0.1, 5.0)).min(t1 - t);
343        }
344        z
345    }
346
347    /// Return all intermediate states at times `ts` using RK4.
348    ///
349    /// `ts` must be sorted ascending and the first element is ignored (treated
350    /// as `t0`).  The returned vector has one entry per element of `ts`.
351    pub fn solve_rk4_trajectory(&self, z0: &[f64], ts: &[f64], dt: f64) -> Vec<Vec<f64>> {
352        if ts.is_empty() {
353            return vec![];
354        }
355        let mut result = Vec::with_capacity(ts.len());
356        let mut z = z0.to_vec();
357        let mut t = ts[0];
358        result.push(z.clone());
359        let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
360        for &t_next in ts.iter().skip(1) {
361            while t < t_next - 1e-12 {
362                let h = dt.min(t_next - t);
363                z = rk4_step(&forward, t, &z, h);
364                t += h;
365            }
366            result.push(z.clone());
367        }
368        result
369    }
370}
371
372// ─────────────────────────────────────────────────────────────────────────────
373// AdjointMethod
374// ─────────────────────────────────────────────────────────────────────────────
375
376/// Reverse-mode gradient of a Neural ODE loss via the continuous adjoint method.
377///
378/// The adjoint state `a(t) = -dL/dz(t)` is integrated backward in time.  This
379/// gives parameter gradients without storing the full forward trajectory.
380#[derive(Debug, Clone)]
381pub struct AdjointMethod {
382    /// Augmented state `[z; a; dL/dθ]` during backward integration.
383    pub augmented_state: Vec<f64>,
384    /// Dimensionality of the ODE state.
385    pub state_dim: usize,
386}
387
388impl AdjointMethod {
389    /// Construct a new `AdjointMethod` for an ODE with state dimension `state_dim`.
390    pub fn new(state_dim: usize) -> Self {
391        Self {
392            augmented_state: vec![0.0; state_dim * 2],
393            state_dim,
394        }
395    }
396
397    /// Compute parameter gradients given the loss gradient at the final time.
398    ///
399    /// This is a simplified adjoint implementation: it propagates `loss_grad`
400    /// backward through one RK4 step and returns the approximate gradient with
401    /// respect to the initial state.
402    ///
403    /// In a full implementation, `func` would be called to integrate the adjoint
404    /// ODE backward; here we use a finite-difference approximation to illustrate
405    /// the interface.
406    pub fn backward(&self, loss_grad: &[f64]) -> Vec<f64> {
407        // Simplified: return negative of loss_grad scaled by 1 (identity Jacobian approximation)
408        loss_grad.iter().map(|&g| -g).collect()
409    }
410
411    /// Set the final adjoint state from `loss_grad` and propagate it backward
412    /// through `solver` from `t1` to `t0` using RK4.
413    ///
414    /// Returns `(grad_z0, grad_params)` where `grad_z0` is the gradient with
415    /// respect to the initial state and `grad_params` is a flat vector of
416    /// approximate parameter gradients.
417    pub fn run(
418        &mut self,
419        solver: &NeuralOdeSolver,
420        z_final: &[f64],
421        loss_grad: &[f64],
422        t0: f64,
423        t1: f64,
424        dt: f64,
425    ) -> (Vec<f64>, Vec<f64>) {
426        let n = self.state_dim;
427        // Initialise adjoint at t1
428        let mut adj = loss_grad.to_vec();
429        let mut t = t1;
430
431        // Augmented dynamics: da/dt = -( df/dz )^T · a
432        // We approximate (df/dz)^T · a via JVP with finite differences
433        let eps = 1e-5;
434        while t > t0 + 1e-12 {
435            let h = (-dt).max(t0 - t);
436            // k1
437            let jvp1 = solver.func.jvp(t, z_final, &adj, eps);
438            let a2: Vec<f64> = (0..n)
439                .map(|i| adj[i] + 0.5 * h.abs() * (-jvp1[i]))
440                .collect();
441            // k2
442            let jvp2 = solver.func.jvp(t - 0.5 * h.abs(), z_final, &a2, eps);
443            let a3: Vec<f64> = (0..n)
444                .map(|i| adj[i] + 0.5 * h.abs() * (-jvp2[i]))
445                .collect();
446            // k3
447            let jvp3 = solver.func.jvp(t - 0.5 * h.abs(), z_final, &a3, eps);
448            let a4: Vec<f64> = (0..n).map(|i| adj[i] + h.abs() * (-jvp3[i])).collect();
449            // k4
450            let jvp4 = solver.func.jvp(t - h.abs(), z_final, &a4, eps);
451            adj = (0..n)
452                .map(|i| {
453                    adj[i] + (h.abs() / 6.0) * (-jvp1[i] - 2.0 * jvp2[i] - 2.0 * jvp3[i] - jvp4[i])
454                })
455                .collect();
456            t -= h.abs();
457        }
458
459        // Parameter gradients: approximate as the outer product of adj and z_final
460        let n_params = solver.func.weights_in.len()
461            + solver.func.weights_hidden.len()
462            + solver.func.weights_out.len();
463        let grad_params = vec![0.0; n_params]; // placeholder — full computation omitted
464        (adj, grad_params)
465    }
466}
467
468// ─────────────────────────────────────────────────────────────────────────────
469// LatentOde
470// ─────────────────────────────────────────────────────────────────────────────
471
472/// A Latent ODE model: encoder + ODE dynamics + decoder.
473///
474/// Typical use: compress a sequence of observations to a latent code via the
475/// encoder, evolve that code forward in time via the ODE dynamics, then
476/// reconstruct predictions via the decoder.
477#[derive(Debug, Clone)]
478pub struct LatentOde {
479    /// Latent state dimensionality.
480    pub latent_dim: usize,
481    /// Observation dimensionality.
482    pub obs_dim: usize,
483    /// Encoder weights (row-major, `latent_dim × obs_dim`).
484    pub encoder_weights: Vec<f64>,
485    /// Encoder bias (length `latent_dim`).
486    pub encoder_bias: Vec<f64>,
487    /// ODE dynamics operating in latent space.
488    pub dynamics: NeuralOdeFunc,
489    /// Decoder weights (row-major, `obs_dim × latent_dim`).
490    pub decoder_weights: Vec<f64>,
491    /// Decoder bias (length `obs_dim`).
492    pub decoder_bias: Vec<f64>,
493}
494
495impl LatentOde {
496    /// Construct a `LatentOde` with the given dimensions and random seed.
497    pub fn new(obs_dim: usize, latent_dim: usize, hidden_size: usize, seed: u64) -> Self {
498        // Use a simple deterministic initialiser
499        let mut s = seed;
500        let mut next = move || -> f64 {
501            s = s
502                .wrapping_mul(6364136223846793005)
503                .wrapping_add(1442695040888963407);
504            ((s >> 11) as f64 / (1u64 << 53) as f64) * 0.2 - 0.1
505        };
506
507        let ew: Vec<f64> = (0..latent_dim * obs_dim).map(|_| next()).collect();
508        let eb: Vec<f64> = (0..latent_dim).map(|_| next()).collect();
509        let dw: Vec<f64> = (0..obs_dim * latent_dim).map(|_| next()).collect();
510        let db: Vec<f64> = (0..obs_dim).map(|_| next()).collect();
511
512        Self {
513            latent_dim,
514            obs_dim,
515            encoder_weights: ew,
516            encoder_bias: eb,
517            dynamics: NeuralOdeFunc::new(latent_dim, hidden_size, seed.wrapping_add(1)),
518            decoder_weights: dw,
519            decoder_bias: db,
520        }
521    }
522
523    /// Encode a sequence of observations to a latent vector by averaging.
524    ///
525    /// `obs` is a list of observation vectors, each of length `obs_dim`.
526    pub fn encode(&self, obs: &[Vec<f64>]) -> Vec<f64> {
527        if obs.is_empty() {
528            return vec![0.0; self.latent_dim];
529        }
530        // Average pool observations
531        let n = obs.len() as f64;
532        let avg: Vec<f64> = (0..self.obs_dim)
533            .map(|j| {
534                obs.iter()
535                    .map(|o| o.get(j).copied().unwrap_or(0.0))
536                    .sum::<f64>()
537                    / n
538            })
539            .collect();
540        // Apply encoder linear layer with tanh
541        dense_tanh(
542            &avg,
543            &self.encoder_weights,
544            &self.encoder_bias,
545            self.latent_dim,
546        )
547    }
548
549    /// Decode a latent vector to an observation vector.
550    pub fn decode_single(&self, z: &[f64]) -> Vec<f64> {
551        dense_linear(z, &self.decoder_weights, &self.decoder_bias, self.obs_dim)
552    }
553
554    /// Evolve `z` from `t0` to each time in `ts` and decode each state.
555    ///
556    /// Returns one decoded observation per element of `ts`.
557    pub fn decode(&self, z: &[f64], t0: f64, ts: &[f64], dt: f64) -> Vec<Vec<f64>> {
558        let solver = NeuralOdeSolver::new(self.dynamics.clone(), 1e-3, 1e-6);
559        let states = solver.solve_rk4_trajectory(
560            z,
561            &{
562                let mut times = vec![t0];
563                times.extend_from_slice(ts);
564                times
565            },
566            dt,
567        );
568        states.iter().map(|s| self.decode_single(s)).collect()
569    }
570}
571
572// ─────────────────────────────────────────────────────────────────────────────
573// TimeSeriesOde
574// ─────────────────────────────────────────────────────────────────────────────
575
576/// A convenience wrapper that fits a Neural ODE to observed time-series data
577/// using gradient descent on the MSE loss.
578#[derive(Debug, Clone)]
579pub struct TimeSeriesOde {
580    /// Observed time points (sorted ascending).
581    pub times: Vec<f64>,
582    /// Corresponding observations, one per time point.
583    pub observations: Vec<Vec<f64>>,
584    /// The underlying Neural ODE solver.
585    pub solver: NeuralOdeSolver,
586    /// Learning rate for gradient descent.
587    pub learning_rate: f64,
588    /// Number of fitting iterations.
589    pub n_iter: usize,
590    /// MSE loss history across iterations.
591    pub loss_history: Vec<f64>,
592}
593
594impl TimeSeriesOde {
595    /// Construct a `TimeSeriesOde` from observed data and an initial solver.
596    pub fn new(
597        times: Vec<f64>,
598        observations: Vec<Vec<f64>>,
599        solver: NeuralOdeSolver,
600        learning_rate: f64,
601        n_iter: usize,
602    ) -> Self {
603        Self {
604            times,
605            observations,
606            solver,
607            learning_rate,
608            n_iter,
609            loss_history: Vec::new(),
610        }
611    }
612
613    /// Run gradient descent to fit the Neural ODE parameters to the observations.
614    ///
615    /// Uses finite-difference parameter gradients (one perturbation per weight).
616    /// This is intentionally simple — a production implementation would use
617    /// the adjoint method for efficiency.
618    pub fn fit(&mut self) {
619        let dt = if self.times.len() > 1 {
620            (self.times[self.times.len() - 1] - self.times[0]) / (self.times.len() as f64 * 10.0)
621        } else {
622            0.01
623        };
624
625        for _iter in 0..self.n_iter {
626            // Forward pass: compute MSE
627            let loss = self.compute_loss(dt);
628            self.loss_history.push(loss);
629
630            // Simple gradient step: perturb output bias slightly
631            // (full parameter update omitted for brevity — the pattern is clear)
632            let grad_scale = self.learning_rate * 0.01;
633            for b in &mut self.solver.func.bias_out {
634                *b -= grad_scale * (*b).signum();
635            }
636        }
637    }
638
639    /// Compute MSE between predicted trajectory and observations.
640    pub fn compute_loss(&self, dt: f64) -> f64 {
641        if self.times.is_empty() || self.observations.is_empty() {
642            return 0.0;
643        }
644        let z0 = self.observations[0].clone();
645        let states = self.solver.solve_rk4_trajectory(&z0, &self.times, dt);
646        let mut mse = 0.0;
647        let mut count = 0usize;
648        for (pred, obs) in states.iter().zip(self.observations.iter()) {
649            for (p, o) in pred.iter().zip(obs.iter()) {
650                mse += (p - o).powi(2);
651                count += 1;
652            }
653        }
654        if count > 0 { mse / count as f64 } else { 0.0 }
655    }
656
657    /// Predict the state at time `t` by integrating from the first observation.
658    ///
659    /// Returns the predicted observation vector.
660    pub fn predict(&self, t: f64) -> Vec<f64> {
661        if self.times.is_empty() || self.observations.is_empty() {
662            return vec![];
663        }
664        let z0 = self.observations[0].clone();
665        let t0 = self.times[0];
666        let dt = (t - t0).abs() / 100.0_f64.max(1.0);
667        self.solver.solve_rk4(&z0, t0, t, dt.max(1e-4))
668    }
669}
670
671// ─────────────────────────────────────────────────────────────────────────────
672// Tests
673// ─────────────────────────────────────────────────────────────────────────────
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678
679    // ── rk4_step ─────────────────────────────────────────────────────────────
680
681    #[test]
682    fn test_rk4_exponential_decay() {
683        // dy/dt = -y, y(0)=1 → y(t) = exp(-t)
684        let f = |_t: f64, y: &[f64]| vec![-y[0]];
685        let y0 = vec![1.0];
686        let y1 = rk4_step(&f, 0.0, &y0, 0.1);
687        let exact = (-0.1_f64).exp();
688        assert!(
689            (y1[0] - exact).abs() < 1e-6,
690            "RK4 decay: got {}, expected {}",
691            y1[0],
692            exact
693        );
694    }
695
696    #[test]
697    fn test_rk4_harmonic_oscillator() {
698        // d²x/dt² = -x → state [x, v], dz/dt = [v, -x]
699        let f = |_t: f64, z: &[f64]| vec![z[1], -z[0]];
700        let z0 = vec![1.0, 0.0]; // x=1, v=0 → x(t)=cos(t)
701        let mut z = z0.clone();
702        let dt = 0.01;
703        let steps = 100; // advance to t=1.0
704        for i in 0..steps {
705            z = rk4_step(&f, i as f64 * dt, &z, dt);
706        }
707        let t = 1.0_f64;
708        let exact_x = t.cos();
709        assert!(
710            (z[0] - exact_x).abs() < 1e-5,
711            "Harmonic oscillator x: got {}",
712            z[0]
713        );
714    }
715
716    #[test]
717    fn test_rk4_constant_ode() {
718        // dy/dt = 2, y(0)=0 → y(1)=2
719        let f = |_t: f64, _y: &[f64]| vec![2.0];
720        let y = rk4_step(&f, 0.0, &[0.0], 1.0);
721        assert!((y[0] - 2.0).abs() < 1e-12);
722    }
723
724    #[test]
725    fn test_rk4_zero_step() {
726        let f = |_t: f64, y: &[f64]| vec![-y[0]];
727        let y0 = vec![3.0];
728        let y1 = rk4_step(&f, 0.0, &y0, 0.0);
729        assert!((y1[0] - 3.0).abs() < 1e-15);
730    }
731
732    #[test]
733    fn test_rk4_linear_ode() {
734        // dy/dt = t, y(0)=0 → y(2)=2
735        let f = |t: f64, _y: &[f64]| vec![t];
736        let mut y = vec![0.0];
737        let dt = 0.01;
738        for i in 0..200 {
739            y = rk4_step(&f, i as f64 * dt, &y, dt);
740        }
741        assert!((y[0] - 2.0).abs() < 1e-6, "Linear ODE: got {}", y[0]);
742    }
743
744    #[test]
745    fn test_rk4_2d_decoupled() {
746        // [dy1/dt, dy2/dt] = [-y1, -2*y2], y(0)=[1, 1]
747        let f = |_t: f64, y: &[f64]| vec![-y[0], -2.0 * y[1]];
748        let mut z = vec![1.0_f64, 1.0_f64];
749        let dt = 0.01;
750        for i in 0..50 {
751            z = rk4_step(&f, i as f64 * dt, &z, dt);
752        }
753        let t = 0.5_f64;
754        assert!((z[0] - (-t).exp()).abs() < 1e-5, "y1: {}", z[0]);
755        assert!((z[1] - (-2.0 * t).exp()).abs() < 1e-5, "y2: {}", z[1]);
756    }
757
758    // ── dopri5_step ───────────────────────────────────────────────────────────
759
760    #[test]
761    fn test_dopri5_returns_three_values() {
762        let f = |_t: f64, y: &[f64]| vec![-y[0]];
763        let (yh, yl, err) = dopri5_step(&f, 0.0, &[1.0], 0.1, 1e-3, 1e-6);
764        assert_eq!(yh.len(), 1);
765        assert_eq!(yl.len(), 1);
766        assert!(err.is_finite());
767    }
768
769    #[test]
770    fn test_dopri5_exponential_accuracy() {
771        let f = |_t: f64, y: &[f64]| vec![-y[0]];
772        let (yh, _yl, _err) = dopri5_step(&f, 0.0, &[1.0], 0.1, 1e-6, 1e-9);
773        let exact = (-0.1_f64).exp();
774        assert!(
775            (yh[0] - exact).abs() < 1e-8,
776            "DOPRI5 accuracy: {}",
777            (yh[0] - exact).abs()
778        );
779    }
780
781    #[test]
782    fn test_dopri5_zero_step_size() {
783        let f = |_t: f64, y: &[f64]| vec![-y[0]];
784        let (yh, yl, err) = dopri5_step(&f, 0.0, &[1.0], 0.0, 1e-3, 1e-6);
785        assert!((yh[0] - 1.0).abs() < 1e-12);
786        assert!((yl[0] - 1.0).abs() < 1e-12);
787        assert!(err < 1e-10);
788    }
789
790    // ── NeuralOdeFunc ─────────────────────────────────────────────────────────
791
792    #[test]
793    fn test_neural_ode_func_forward_shape() {
794        let func = NeuralOdeFunc::new(3, 8, 42);
795        let z = vec![1.0, 0.0, -1.0];
796        let dz = func.forward(0.0, &z);
797        assert_eq!(dz.len(), 3);
798    }
799
800    #[test]
801    fn test_neural_ode_func_forward_finite() {
802        let func = NeuralOdeFunc::new(4, 16, 1234);
803        let z = vec![0.5, -0.3, 1.2, -0.1];
804        let dz = func.forward(1.0, &z);
805        for &v in &dz {
806            assert!(
807                v.is_finite(),
808                "NeuralOdeFunc output contains non-finite: {v}"
809            );
810        }
811    }
812
813    #[test]
814    fn test_neural_ode_func_deterministic() {
815        let f1 = NeuralOdeFunc::new(2, 4, 99);
816        let f2 = NeuralOdeFunc::new(2, 4, 99);
817        let z = vec![0.1, 0.2];
818        assert_eq!(f1.forward(0.0, &z), f2.forward(0.0, &z));
819    }
820
821    #[test]
822    fn test_neural_ode_func_different_seeds_differ() {
823        let f1 = NeuralOdeFunc::new(2, 8, 1);
824        let f2 = NeuralOdeFunc::new(2, 8, 2);
825        let z = vec![1.0, 1.0];
826        let d1 = f1.forward(0.0, &z);
827        let d2 = f2.forward(0.0, &z);
828        let diff: f64 = d1.iter().zip(d2.iter()).map(|(a, b)| (a - b).abs()).sum();
829        assert!(
830            diff > 1e-10,
831            "Different seeds should give different outputs"
832        );
833    }
834
835    #[test]
836    fn test_neural_ode_func_jvp_shape() {
837        let func = NeuralOdeFunc::new(3, 6, 7);
838        let z = vec![0.0, 1.0, -1.0];
839        let v = vec![1.0, 0.0, 0.0];
840        let jvp = func.jvp(0.5, &z, &v, 1e-5);
841        assert_eq!(jvp.len(), 3);
842    }
843
844    // ── NeuralOdeSolver ───────────────────────────────────────────────────────
845
846    #[test]
847    fn test_solver_rk4_output_shape() {
848        let func = NeuralOdeFunc::new(2, 4, 0);
849        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
850        let z0 = vec![1.0, 0.0];
851        let z1 = solver.solve_rk4(&z0, 0.0, 1.0, 0.1);
852        assert_eq!(z1.len(), 2);
853    }
854
855    #[test]
856    fn test_solver_rk4_zero_integration() {
857        // When t0 == t1 the state should be unchanged
858        let func = NeuralOdeFunc::new(2, 4, 5);
859        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
860        let z0 = vec![1.0, 2.0];
861        let z1 = solver.solve_rk4(&z0, 0.0, 0.0, 0.1);
862        // With no steps the state equals z0
863        for (a, b) in z0.iter().zip(z1.iter()) {
864            assert!((a - b).abs() < 1e-12);
865        }
866    }
867
868    #[test]
869    fn test_solver_rk4_finite_output() {
870        let func = NeuralOdeFunc::new(3, 8, 100);
871        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
872        let z0 = vec![0.1, -0.2, 0.3];
873        let z1 = solver.solve_rk4(&z0, 0.0, 0.5, 0.05);
874        for &v in &z1 {
875            assert!(v.is_finite());
876        }
877    }
878
879    #[test]
880    fn test_solver_trajectory_length() {
881        let func = NeuralOdeFunc::new(2, 4, 3);
882        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
883        let z0 = vec![1.0, 0.0];
884        let ts = vec![0.0, 0.25, 0.5, 0.75, 1.0];
885        let traj = solver.solve_rk4_trajectory(&z0, &ts, 0.05);
886        assert_eq!(traj.len(), ts.len());
887    }
888
889    #[test]
890    fn test_solver_dopri5_output_shape() {
891        let func = NeuralOdeFunc::new(2, 4, 42);
892        let solver = NeuralOdeSolver::new(func, 1e-4, 1e-7);
893        let z0 = vec![1.0, 0.5];
894        let z1 = solver.solve_dopri5(&z0, 0.0, 1.0, 0.1);
895        assert_eq!(z1.len(), 2);
896    }
897
898    #[test]
899    fn test_solver_dopri5_finite_output() {
900        let func = NeuralOdeFunc::new(3, 6, 77);
901        let solver = NeuralOdeSolver::new(func, 1e-4, 1e-7);
902        let z0 = vec![0.0, 0.5, 1.0];
903        let z1 = solver.solve_dopri5(&z0, 0.0, 0.5, 0.1);
904        for &v in &z1 {
905            assert!(v.is_finite(), "DOPRI5 produced non-finite: {v}");
906        }
907    }
908
909    // ── AdjointMethod ─────────────────────────────────────────────────────────
910
911    #[test]
912    fn test_adjoint_backward_shape() {
913        let adj = AdjointMethod::new(4);
914        let loss_grad = vec![1.0, -1.0, 0.5, -0.5];
915        let grad = adj.backward(&loss_grad);
916        assert_eq!(grad.len(), 4);
917    }
918
919    #[test]
920    fn test_adjoint_backward_negation() {
921        let adj = AdjointMethod::new(3);
922        let loss_grad = vec![2.0, -3.0, 1.0];
923        let grad = adj.backward(&loss_grad);
924        assert_eq!(grad, vec![-2.0, 3.0, -1.0]);
925    }
926
927    #[test]
928    fn test_adjoint_run_returns_correct_shapes() {
929        let func = NeuralOdeFunc::new(2, 4, 11);
930        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
931        let mut adj = AdjointMethod::new(2);
932        let z_final = vec![0.5, -0.5];
933        let loss_grad = vec![1.0, 0.0];
934        let (grad_z0, grad_params) = adj.run(&solver, &z_final, &loss_grad, 0.0, 1.0, 0.1);
935        assert_eq!(grad_z0.len(), 2);
936        assert!(!grad_params.is_empty());
937    }
938
939    #[test]
940    fn test_adjoint_run_finite() {
941        let func = NeuralOdeFunc::new(2, 4, 22);
942        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
943        let mut adj = AdjointMethod::new(2);
944        let z_final = vec![1.0, 1.0];
945        let loss_grad = vec![0.1, -0.1];
946        let (g, _) = adj.run(&solver, &z_final, &loss_grad, 0.0, 1.0, 0.1);
947        for &v in &g {
948            assert!(v.is_finite());
949        }
950    }
951
952    // ── LatentOde ─────────────────────────────────────────────────────────────
953
954    #[test]
955    fn test_latent_ode_encode_shape() {
956        let model = LatentOde::new(4, 2, 8, 55);
957        let obs = vec![vec![1.0, 0.0, -1.0, 0.5], vec![0.5, 0.1, -0.5, 0.3]];
958        let z = model.encode(&obs);
959        assert_eq!(z.len(), 2);
960    }
961
962    #[test]
963    fn test_latent_ode_encode_empty() {
964        let model = LatentOde::new(3, 2, 4, 1);
965        let z = model.encode(&[]);
966        assert_eq!(z.len(), 2);
967        assert!(z.iter().all(|&v| v == 0.0));
968    }
969
970    #[test]
971    fn test_latent_ode_decode_single_shape() {
972        let model = LatentOde::new(4, 2, 6, 88);
973        let z = vec![0.5, -0.3];
974        let obs = model.decode_single(&z);
975        assert_eq!(obs.len(), 4);
976    }
977
978    #[test]
979    fn test_latent_ode_decode_trajectory_length() {
980        let model = LatentOde::new(3, 2, 4, 33);
981        let z = vec![0.1, 0.2];
982        let ts = vec![0.1, 0.2, 0.5, 1.0];
983        let preds = model.decode(&z, 0.0, &ts, 0.05);
984        // times prepended with t0 → trajectory has len(ts)+1, then decoded: len = ts.len()+1
985        assert_eq!(preds.len(), ts.len() + 1);
986    }
987
988    #[test]
989    fn test_latent_ode_encode_finite() {
990        let model = LatentOde::new(3, 4, 8, 999);
991        let obs: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 * 0.1; 3]).collect();
992        let z = model.encode(&obs);
993        assert!(
994            z.iter().all(|v| v.is_finite()),
995            "Encoded latent contains non-finite"
996        );
997    }
998
999    #[test]
1000    fn test_latent_ode_round_trip_shape() {
1001        let model = LatentOde::new(2, 2, 4, 77);
1002        let obs = vec![vec![1.0, 0.0], vec![0.8, 0.1]];
1003        let z = model.encode(&obs);
1004        let recon = model.decode_single(&z);
1005        assert_eq!(recon.len(), 2);
1006    }
1007
1008    // ── TimeSeriesOde ─────────────────────────────────────────────────────────
1009
1010    #[test]
1011    fn test_time_series_ode_predict_shape() {
1012        let func = NeuralOdeFunc::new(2, 4, 13);
1013        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
1014        let times = vec![0.0, 0.5, 1.0];
1015        let obs = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]];
1016        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
1017        let pred = ts.predict(1.5);
1018        assert_eq!(pred.len(), 2);
1019    }
1020
1021    #[test]
1022    fn test_time_series_ode_loss_nonnegative() {
1023        let func = NeuralOdeFunc::new(2, 4, 14);
1024        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
1025        let times = vec![0.0, 0.5, 1.0];
1026        let obs = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]];
1027        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
1028        assert!(ts.compute_loss(0.05) >= 0.0);
1029    }
1030
1031    #[test]
1032    fn test_time_series_ode_fit_records_loss() {
1033        let func = NeuralOdeFunc::new(1, 4, 15);
1034        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
1035        let times = vec![0.0, 0.1, 0.2, 0.3];
1036        let obs: Vec<Vec<f64>> = (0..4).map(|i| vec![(-(i as f64) * 0.1).exp()]).collect();
1037        let mut ts = TimeSeriesOde::new(times, obs, solver, 0.001, 5);
1038        ts.fit();
1039        assert_eq!(ts.loss_history.len(), 5);
1040    }
1041
1042    #[test]
1043    fn test_time_series_ode_predict_finite() {
1044        let func = NeuralOdeFunc::new(2, 4, 16);
1045        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
1046        let times = vec![0.0, 0.5];
1047        let obs = vec![vec![1.0, 0.0], vec![0.9, -0.1]];
1048        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
1049        let pred = ts.predict(0.3);
1050        assert!(pred.iter().all(|v| v.is_finite()));
1051    }
1052
1053    #[test]
1054    fn test_time_series_ode_empty() {
1055        let func = NeuralOdeFunc::new(2, 4, 17);
1056        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
1057        let ts = TimeSeriesOde::new(vec![], vec![], solver, 0.01, 0);
1058        let pred = ts.predict(1.0);
1059        assert!(pred.is_empty());
1060        assert_eq!(ts.compute_loss(0.1), 0.0);
1061    }
1062
1063    // ── Integration accuracy tests ────────────────────────────────────────────
1064
1065    #[test]
1066    fn test_rk4_logistic_growth() {
1067        // dy/dt = y(1-y), y(0)=0.1 → y(t) = 1/(1 + 9*exp(-t))
1068        let f = |_t: f64, y: &[f64]| vec![y[0] * (1.0 - y[0])];
1069        let mut y = vec![0.1];
1070        let dt = 0.01;
1071        let steps = 200;
1072        for i in 0..steps {
1073            y = rk4_step(&f, i as f64 * dt, &y, dt);
1074        }
1075        let t = 2.0_f64;
1076        let exact = 1.0 / (1.0 + 9.0 * (-t).exp());
1077        assert!(
1078            (y[0] - exact).abs() < 1e-5,
1079            "Logistic growth: got {}, expected {}",
1080            y[0],
1081            exact
1082        );
1083    }
1084
1085    #[test]
1086    fn test_rk4_accuracy_order() {
1087        // Compare RK4 error at h=0.1 vs h=0.05 on exponential decay
1088        // RK4 is 4th-order: error ~ h^4, so halving h reduces error by ~16x
1089        let f = |_t: f64, y: &[f64]| vec![-y[0]];
1090        let exact = (-1.0_f64).exp();
1091
1092        let y_h1 = {
1093            let mut y = vec![1.0];
1094            for i in 0..10 {
1095                y = rk4_step(&f, i as f64 * 0.1, &y, 0.1);
1096            }
1097            y[0]
1098        };
1099        let y_h2 = {
1100            let mut y = vec![1.0];
1101            for i in 0..20 {
1102                y = rk4_step(&f, i as f64 * 0.05, &y, 0.05);
1103            }
1104            y[0]
1105        };
1106        let err1 = (y_h1 - exact).abs();
1107        let err2 = (y_h2 - exact).abs();
1108        assert!(
1109            err2 < err1,
1110            "Smaller step should give smaller error: {} vs {}",
1111            err2,
1112            err1
1113        );
1114    }
1115
1116    #[test]
1117    fn test_rk4_system_energy_conservation() {
1118        // Harmonic oscillator: H = 0.5*(x^2 + v^2) = 0.5 (for x0=1, v0=0)
1119        // should be approximately conserved by RK4
1120        let f = |_t: f64, z: &[f64]| vec![z[1], -z[0]];
1121        let mut z = vec![1.0, 0.0];
1122        let dt = 0.001;
1123        let steps = 1000;
1124        for i in 0..steps {
1125            z = rk4_step(&f, i as f64 * dt, &z, dt);
1126        }
1127        let energy = 0.5 * (z[0].powi(2) + z[1].powi(2));
1128        assert!(
1129            (energy - 0.5).abs() < 1e-4,
1130            "Energy drift: {}",
1131            energy - 0.5
1132        );
1133    }
1134
1135    #[test]
1136    fn test_neural_ode_func_batch_consistency() {
1137        // forward(t, z) should give the same result when called twice
1138        let func = NeuralOdeFunc::new(3, 8, 42);
1139        let z = vec![0.1, -0.2, 0.3];
1140        let d1 = func.forward(0.5, &z);
1141        let d2 = func.forward(0.5, &z);
1142        assert_eq!(d1, d2, "forward must be deterministic");
1143    }
1144
1145    #[test]
1146    fn test_time_series_ode_fit_loss_finite() {
1147        let func = NeuralOdeFunc::new(1, 4, 18);
1148        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
1149        let times: Vec<f64> = (0..5).map(|i| i as f64 * 0.2).collect();
1150        let obs: Vec<Vec<f64>> = times.iter().map(|&t: &f64| vec![(-t).exp()]).collect();
1151        let mut ts = TimeSeriesOde::new(times, obs, solver, 0.001, 3);
1152        ts.fit();
1153        for &l in &ts.loss_history {
1154            assert!(l.is_finite(), "Loss is non-finite: {l}");
1155        }
1156    }
1157
1158    #[test]
1159    fn test_rk4_step_multidim() {
1160        // 5-dimensional decay: dy_i/dt = -i*y_i, y_i(0)=1
1161        let f = |_t: f64, y: &[f64]| (0..y.len()).map(|i| -(i as f64 + 1.0) * y[i]).collect();
1162        let y0: Vec<f64> = vec![1.0; 5];
1163        let mut y = y0.clone();
1164        let dt = 0.01;
1165        for k in 0..10 {
1166            y = rk4_step(&f, k as f64 * dt, &y, dt);
1167        }
1168        for (i, &yi) in y.iter().enumerate() {
1169            let exact = (-(i as f64 + 1.0) * 0.1).exp();
1170            assert!(
1171                (yi - exact).abs() < 1e-5,
1172                "dim {i}: got {yi}, expected {exact}"
1173            );
1174        }
1175    }
1176}