1#![allow(dead_code)]
16#![allow(clippy::too_many_arguments)]
17
/// Advances `y' = f(t, y)` by one classical fourth-order Runge–Kutta step of size `h`.
///
/// The four slopes are sampled at `t`, `t + h/2` (twice) and `t + h`; the returned state is
/// `y + h/6 * (s1 + 2 s2 + 2 s3 + s4)`, computed component-wise over `y.len()` entries.
/// Panics if `f` returns a vector shorter than `y`.
pub fn rk4_step(f: &dyn Fn(f64, &[f64]) -> Vec<f64>, t: f64, y: &[f64], h: f64) -> Vec<f64> {
    let dim = y.len();
    // Trial point `y + scale * slope`; indexing (not zip) keeps the panic on a short slope.
    let probe = |slope: &[f64], scale: f64| -> Vec<f64> {
        (0..dim).map(|i| y[i] + scale * slope[i]).collect()
    };
    let t_mid = t + 0.5 * h;

    let s1 = f(t, y);
    let s2 = f(t_mid, &probe(&s1, 0.5 * h));
    let s3 = f(t_mid, &probe(&s2, 0.5 * h));
    let s4 = f(t + h, &probe(&s3, h));

    let weight = h / 6.0;
    let mut next = Vec::with_capacity(dim);
    for i in 0..dim {
        next.push(y[i] + weight * (s1[i] + 2.0 * s2[i] + 2.0 * s3[i] + s4[i]));
    }
    next
}
45
/// Performs one Dormand–Prince 5(4) step for `y' = f(t, y)` from `(t, y)` with step size `h`.
///
/// Returns `(y_high, y_low, error_norm)`:
/// - `y_high`: the fifth-order solution at `t + h` (the value to propagate);
/// - `y_low`: the embedded fourth-order solution at `t + h`;
/// - `error_norm`: root-mean-square over components of
///   `(y_high - y_low) / (atol + rtol * max(|y|, |y_high|))`. A value `<= 1.0` means the step
///   meets the requested tolerances. An empty state yields `0.0` (not NaN).
///
/// `f` is evaluated seven times; the seventh evaluation is at `(t + h, y_high)`.
pub fn dopri5_step(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    t: f64,
    y: &[f64],
    h: f64,
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    let n = y.len();
    // Nodes c_i of the Dormand–Prince tableau (c6 = c7 = 1).
    let c2 = 1.0 / 5.0;
    let c3 = 3.0 / 10.0;
    let c4 = 4.0 / 5.0;
    let c5 = 8.0 / 9.0;

    let k1 = f(t, y);

    let y2: Vec<f64> = (0..n).map(|i| y[i] + h * (1.0 / 5.0) * k1[i]).collect();
    let k2 = f(t + c2 * h, &y2);

    let y3: Vec<f64> = (0..n)
        .map(|i| y[i] + h * ((3.0 / 40.0) * k1[i] + (9.0 / 40.0) * k2[i]))
        .collect();
    let k3 = f(t + c3 * h, &y3);

    let y4: Vec<f64> = (0..n)
        .map(|i| y[i] + h * ((44.0 / 45.0) * k1[i] - (56.0 / 15.0) * k2[i] + (32.0 / 9.0) * k3[i]))
        .collect();
    let k4 = f(t + c4 * h, &y4);

    let y5: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h
                * ((19372.0 / 6561.0) * k1[i] - (25360.0 / 2187.0) * k2[i]
                    + (64448.0 / 6561.0) * k3[i]
                    - (212.0 / 729.0) * k4[i])
        })
        .collect();
    let k5 = f(t + c5 * h, &y5);

    let y6: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h
                * ((9017.0 / 3168.0) * k1[i] - (355.0 / 33.0) * k2[i]
                    + (46732.0 / 5247.0) * k3[i]
                    + (49.0 / 176.0) * k4[i]
                    - (5103.0 / 18656.0) * k5[i])
        })
        .collect();
    let k6 = f(t + h, &y6);

    // Fifth-order weights b_i (b2 = b7 = 0).
    let y_high: Vec<f64> = (0..n)
        .map(|i| {
            y[i] + h
                * ((35.0 / 384.0) * k1[i] + (500.0 / 1113.0) * k3[i] + (125.0 / 192.0) * k4[i]
                    - (2187.0 / 6784.0) * k5[i]
                    + (11.0 / 84.0) * k6[i])
        })
        .collect();

    // Seventh stage, evaluated at the new point (first-same-as-last).
    let k7 = f(t + h, &y_high);

    // err = y_high - y_low, with e_i = b_i - b*_i (fifth- minus fourth-order weights). Using it
    // directly avoids re-deriving it as y_high - (y_high - err), which loses precision.
    let err: Vec<f64> = (0..n)
        .map(|i| {
            h * ((71.0 / 57600.0) * k1[i] - (71.0 / 16695.0) * k3[i] + (71.0 / 1920.0) * k4[i]
                - (17253.0 / 339200.0) * k5[i]
                + (22.0 / 525.0) * k6[i]
                - (1.0 / 40.0) * k7[i])
        })
        .collect();
    let y_low: Vec<f64> = y_high.iter().zip(&err).map(|(yh, e)| yh - e).collect();

    // An empty state has nothing to control; without this guard the mean is 0/0 = NaN.
    let error_norm = if n == 0 {
        0.0
    } else {
        let mean_sq = (0..n)
            .map(|i| {
                let scale = atol + rtol * y[i].abs().max(y_high[i].abs());
                (err[i] / scale).powi(2)
            })
            .sum::<f64>()
            / n as f64;
        mean_sq.sqrt()
    };

    (y_high, y_low, error_norm)
}
156
/// Applies `tanh` to every element of `v`, returning a new vector of the same length.
fn tanh_vec(v: &[f64]) -> Vec<f64> {
    let mut squashed = Vec::with_capacity(v.len());
    for &x in v {
        squashed.push(x.tanh());
    }
    squashed
}
165
/// Fully connected layer followed by `tanh`: `out_i = tanh(sum_j w[i*len + j] * input[j] + b[i])`.
///
/// `w` is read row-major with row length `input.len()`; row `i` produces output `i`, for
/// `i in 0..out`. Panics if `w` holds fewer than `out * input.len()` entries (when the input is
/// non-empty) or `b` fewer than `out`.
fn dense_tanh(input: &[f64], w: &[f64], b: &[f64], out: usize) -> Vec<f64> {
    let width = input.len();
    let mut activations = Vec::with_capacity(out);
    for row in 0..out {
        let weights = &w[row * width..(row + 1) * width];
        let pre: f64 = weights.iter().zip(input).map(|(wj, xj)| wj * xj).sum::<f64>() + b[row];
        activations.push(pre.tanh());
    }
    activations
}
178
/// Fully connected layer without activation: `out_i = sum_j w[i*len + j] * input[j] + b[i]`.
///
/// `w` is read row-major with row length `input.len()`; one output per row for `i in 0..out`.
/// Panics if `w` holds fewer than `out * input.len()` entries (when the input is non-empty) or
/// `b` fewer than `out`.
fn dense_linear(input: &[f64], w: &[f64], b: &[f64], out: usize) -> Vec<f64> {
    let width = input.len();
    let mut outputs = Vec::with_capacity(out);
    for row in 0..out {
        let weights = &w[row * width..(row + 1) * width];
        let dot: f64 = weights.iter().zip(input).map(|(wj, xj)| wj * xj).sum();
        outputs.push(dot + b[row]);
    }
    outputs
}
186
187#[derive(Debug, Clone)]
196pub struct NeuralOdeFunc {
197 pub input_size: usize,
199 pub hidden_size: usize,
201 pub weights_in: Vec<f64>,
203 pub bias_in: Vec<f64>,
205 pub weights_hidden: Vec<f64>,
207 pub bias_hidden: Vec<f64>,
209 pub weights_out: Vec<f64>,
211 pub bias_out: Vec<f64>,
213}
214
215impl NeuralOdeFunc {
216 pub fn new(input_size: usize, hidden_size: usize, seed: u64) -> Self {
219 let mut rng_state = seed;
220 let mut next = move || -> f64 {
221 rng_state = rng_state
222 .wrapping_mul(6364136223846793005)
223 .wrapping_add(1442695040888963407);
224 let bits = (rng_state >> 11) as f64;
226 (bits / (1u64 << 53) as f64) * 0.2 - 0.1
227 };
228
229 let wi: Vec<f64> = (0..hidden_size * (input_size + 1))
231 .map(|_| next())
232 .collect();
233 let bi: Vec<f64> = (0..hidden_size).map(|_| next()).collect();
234 let wh: Vec<f64> = (0..hidden_size * hidden_size).map(|_| next()).collect();
235 let bh: Vec<f64> = (0..hidden_size).map(|_| next()).collect();
236 let wo: Vec<f64> = (0..input_size * hidden_size).map(|_| next()).collect();
237 let bo: Vec<f64> = (0..input_size).map(|_| next()).collect();
238
239 Self {
240 input_size,
241 hidden_size,
242 weights_in: wi,
243 bias_in: bi,
244 weights_hidden: wh,
245 bias_hidden: bh,
246 weights_out: wo,
247 bias_out: bo,
248 }
249 }
250
251 pub fn forward(&self, t: f64, z: &[f64]) -> Vec<f64> {
256 let mut aug = Vec::with_capacity(self.input_size + 1);
258 aug.extend_from_slice(z);
259 aug.push(t);
260
261 let h1 = dense_tanh(&aug, &self.weights_in, &self.bias_in, self.hidden_size);
262 let h2 = dense_tanh(
263 &h1,
264 &self.weights_hidden,
265 &self.bias_hidden,
266 self.hidden_size,
267 );
268 dense_linear(&h2, &self.weights_out, &self.bias_out, self.input_size)
269 }
270
271 pub fn jvp(&self, t: f64, z: &[f64], v: &[f64], eps: f64) -> Vec<f64> {
275 let f0 = self.forward(t, z);
276 let z_plus: Vec<f64> = z
277 .iter()
278 .zip(v.iter())
279 .map(|(zi, vi)| zi + eps * vi)
280 .collect();
281 let f_plus = self.forward(t, &z_plus);
282 f_plus
283 .iter()
284 .zip(f0.iter())
285 .map(|(fp, f0i)| (fp - f0i) / eps)
286 .collect()
287 }
288
289 pub fn params_flat(&self) -> Vec<f64> {
293 let mut p = Vec::with_capacity(self.n_params());
294 p.extend_from_slice(&self.weights_in);
295 p.extend_from_slice(&self.bias_in);
296 p.extend_from_slice(&self.weights_hidden);
297 p.extend_from_slice(&self.bias_hidden);
298 p.extend_from_slice(&self.weights_out);
299 p.extend_from_slice(&self.bias_out);
300 p
301 }
302
303 pub fn n_params(&self) -> usize {
305 self.weights_in.len()
306 + self.bias_in.len()
307 + self.weights_hidden.len()
308 + self.bias_hidden.len()
309 + self.weights_out.len()
310 + self.bias_out.len()
311 }
312
313 pub fn set_params_flat(&mut self, params: &[f64]) {
315 let mut off = 0;
316 let wi_len = self.weights_in.len();
317 self.weights_in.copy_from_slice(¶ms[off..off + wi_len]);
318 off += wi_len;
319 let bi_len = self.bias_in.len();
320 self.bias_in.copy_from_slice(¶ms[off..off + bi_len]);
321 off += bi_len;
322 let wh_len = self.weights_hidden.len();
323 self.weights_hidden
324 .copy_from_slice(¶ms[off..off + wh_len]);
325 off += wh_len;
326 let bh_len = self.bias_hidden.len();
327 self.bias_hidden.copy_from_slice(¶ms[off..off + bh_len]);
328 off += bh_len;
329 let wo_len = self.weights_out.len();
330 self.weights_out.copy_from_slice(¶ms[off..off + wo_len]);
331 off += wo_len;
332 let bo_len = self.bias_out.len();
333 self.bias_out.copy_from_slice(¶ms[off..off + bo_len]);
334 let _ = off + bo_len;
335 }
336
337 pub fn param_grad_contrib(&self, t: f64, z: &[f64], adj: &[f64], eps: f64) -> Vec<f64> {
342 let n_p = self.n_params();
343 let params = self.params_flat();
344 let mut grad = vec![0.0_f64; n_p];
345 let mut tmp = self.clone();
346 for j in 0..n_p {
347 let mut p_plus = params.clone();
348 let mut p_minus = params.clone();
349 p_plus[j] += eps;
350 p_minus[j] -= eps;
351 tmp.set_params_flat(&p_plus);
352 let f_plus = tmp.forward(t, z);
353 tmp.set_params_flat(&p_minus);
354 let f_minus = tmp.forward(t, z);
355 grad[j] = adj
356 .iter()
357 .zip(f_plus.iter().zip(f_minus.iter()))
358 .map(|(&ai, (&fp, &fm))| ai * (fp - fm) / (2.0 * eps))
359 .sum();
360 }
361 grad
362 }
363}
364
365#[derive(Debug, Clone)]
372pub struct NeuralOdeSolver {
373 pub func: NeuralOdeFunc,
375 pub rtol: f64,
377 pub atol: f64,
379}
380
381impl NeuralOdeSolver {
382 pub fn new(func: NeuralOdeFunc, rtol: f64, atol: f64) -> Self {
384 Self { func, rtol, atol }
385 }
386
387 pub fn solve_rk4(&self, z0: &[f64], t0: f64, t1: f64, dt: f64) -> Vec<f64> {
391 let mut z = z0.to_vec();
392 let mut t = t0;
393 let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
394 while t < t1 - 1e-12 {
395 let h = dt.min(t1 - t);
396 z = rk4_step(&forward, t, &z, h);
397 t += h;
398 }
399 z
400 }
401
402 pub fn solve_dopri5(&self, z0: &[f64], t0: f64, t1: f64, dt_init: f64) -> Vec<f64> {
406 let mut z = z0.to_vec();
407 let mut t = t0;
408 let mut h = dt_init;
409 let max_steps = 100_000usize;
410 let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
411 for _ in 0..max_steps {
412 if t >= t1 - 1e-12 {
413 break;
414 }
415 h = h.min(t1 - t);
416 let (y_high, _y_low, err) = dopri5_step(&forward, t, &z, h, self.rtol, self.atol);
417 if err <= 1.0 || h <= 1e-10 {
418 z = y_high;
419 t += h;
420 }
421 let factor = if err < 1e-14 {
423 5.0
424 } else {
425 0.9 * (1.0 / err).powf(0.2)
426 };
427 h = (h * factor.clamp(0.1, 5.0)).min(t1 - t);
428 }
429 z
430 }
431
432 pub fn solve_rk4_trajectory(&self, z0: &[f64], ts: &[f64], dt: f64) -> Vec<Vec<f64>> {
437 if ts.is_empty() {
438 return vec![];
439 }
440 let mut result = Vec::with_capacity(ts.len());
441 let mut z = z0.to_vec();
442 let mut t = ts[0];
443 result.push(z.clone());
444 let forward = |t: f64, y: &[f64]| self.func.forward(t, y);
445 for &t_next in ts.iter().skip(1) {
446 while t < t_next - 1e-12 {
447 let h = dt.min(t_next - t);
448 z = rk4_step(&forward, t, &z, h);
449 t += h;
450 }
451 result.push(z.clone());
452 }
453 result
454 }
455}
456
457#[derive(Debug, Clone)]
466pub struct AdjointMethod {
467 pub augmented_state: Vec<f64>,
469 pub state_dim: usize,
471}
472
473impl AdjointMethod {
474 pub fn new(state_dim: usize) -> Self {
476 Self {
477 augmented_state: vec![0.0; state_dim * 2],
478 state_dim,
479 }
480 }
481
482 pub fn backward(&self, loss_grad: &[f64]) -> Vec<f64> {
492 loss_grad.iter().map(|&g| -g).collect()
494 }
495
496 pub fn run(
512 &mut self,
513 solver: &NeuralOdeSolver,
514 z_final: &[f64],
515 loss_grad: &[f64],
516 t0: f64,
517 t1: f64,
518 dt: f64,
519 ) -> (Vec<f64>, Vec<f64>) {
520 let n = self.state_dim;
521 let eps = 1e-5;
522 let h_step = dt.abs().max(1e-10);
523
524 let neg_f = |tc: f64, y: &[f64]| -> Vec<f64> {
527 solver.func.forward(tc, y).into_iter().map(|v| -v).collect()
528 };
529 let mut z_bwd = z_final.to_vec();
530 let mut t_cur = t1;
531 let mut times: Vec<f64> = vec![t_cur];
532 let mut states: Vec<Vec<f64>> = vec![z_bwd.clone()];
533 while t_cur > t0 + 1e-12 {
534 let h_bwd = h_step.min(t_cur - t0);
535 z_bwd = rk4_step(&neg_f, t_cur, &z_bwd, h_bwd);
536 t_cur -= h_bwd;
537 times.push(t_cur);
538 states.push(z_bwd.clone());
539 }
540 times.reverse();
542 states.reverse();
543
544 let n_params = solver.func.n_params();
546 let mut adj = loss_grad.to_vec();
547 let mut grad_params = vec![0.0_f64; n_params];
548 let n_ckpt = times.len();
549
550 for ck in (1..n_ckpt).rev() {
551 let t_hi = times[ck];
552 let t_lo = times[ck - 1];
553 let z_ck = &states[ck];
554 let h_abs = (t_hi - t_lo).abs().max(1e-14);
555
556 let pg = solver.func.param_grad_contrib(t_hi, z_ck, &adj, eps);
559 for (g, &pg_j) in grad_params.iter_mut().zip(pg.iter()) {
560 *g += h_abs * pg_j;
561 }
562
563 let jvp1 = solver.func.jvp(t_hi, z_ck, &adj, eps);
565 let a2: Vec<f64> = (0..n).map(|i| adj[i] + 0.5 * h_abs * (-jvp1[i])).collect();
566 let jvp2 = solver.func.jvp(t_hi - 0.5 * h_abs, z_ck, &a2, eps);
567 let a3: Vec<f64> = (0..n).map(|i| adj[i] + 0.5 * h_abs * (-jvp2[i])).collect();
568 let jvp3 = solver.func.jvp(t_hi - 0.5 * h_abs, z_ck, &a3, eps);
569 let a4: Vec<f64> = (0..n).map(|i| adj[i] + h_abs * (-jvp3[i])).collect();
570 let jvp4 = solver.func.jvp(t_lo, z_ck, &a4, eps);
571 adj = (0..n)
572 .map(|i| {
573 adj[i] + (h_abs / 6.0) * (-jvp1[i] - 2.0 * jvp2[i] - 2.0 * jvp3[i] - jvp4[i])
574 })
575 .collect();
576 }
577
578 (adj, grad_params)
579 }
580}
581
582#[derive(Debug, Clone)]
592pub struct LatentOde {
593 pub latent_dim: usize,
595 pub obs_dim: usize,
597 pub encoder_weights: Vec<f64>,
599 pub encoder_bias: Vec<f64>,
601 pub dynamics: NeuralOdeFunc,
603 pub decoder_weights: Vec<f64>,
605 pub decoder_bias: Vec<f64>,
607}
608
609impl LatentOde {
610 pub fn new(obs_dim: usize, latent_dim: usize, hidden_size: usize, seed: u64) -> Self {
612 let mut s = seed;
614 let mut next = move || -> f64 {
615 s = s
616 .wrapping_mul(6364136223846793005)
617 .wrapping_add(1442695040888963407);
618 ((s >> 11) as f64 / (1u64 << 53) as f64) * 0.2 - 0.1
619 };
620
621 let ew: Vec<f64> = (0..latent_dim * obs_dim).map(|_| next()).collect();
622 let eb: Vec<f64> = (0..latent_dim).map(|_| next()).collect();
623 let dw: Vec<f64> = (0..obs_dim * latent_dim).map(|_| next()).collect();
624 let db: Vec<f64> = (0..obs_dim).map(|_| next()).collect();
625
626 Self {
627 latent_dim,
628 obs_dim,
629 encoder_weights: ew,
630 encoder_bias: eb,
631 dynamics: NeuralOdeFunc::new(latent_dim, hidden_size, seed.wrapping_add(1)),
632 decoder_weights: dw,
633 decoder_bias: db,
634 }
635 }
636
637 pub fn encode(&self, obs: &[Vec<f64>]) -> Vec<f64> {
641 if obs.is_empty() {
642 return vec![0.0; self.latent_dim];
643 }
644 let n = obs.len() as f64;
646 let avg: Vec<f64> = (0..self.obs_dim)
647 .map(|j| {
648 obs.iter()
649 .map(|o| o.get(j).copied().unwrap_or(0.0))
650 .sum::<f64>()
651 / n
652 })
653 .collect();
654 dense_tanh(
656 &avg,
657 &self.encoder_weights,
658 &self.encoder_bias,
659 self.latent_dim,
660 )
661 }
662
663 pub fn decode_single(&self, z: &[f64]) -> Vec<f64> {
665 dense_linear(z, &self.decoder_weights, &self.decoder_bias, self.obs_dim)
666 }
667
668 pub fn decode(&self, z: &[f64], t0: f64, ts: &[f64], dt: f64) -> Vec<Vec<f64>> {
672 let solver = NeuralOdeSolver::new(self.dynamics.clone(), 1e-3, 1e-6);
673 let states = solver.solve_rk4_trajectory(
674 z,
675 &{
676 let mut times = vec![t0];
677 times.extend_from_slice(ts);
678 times
679 },
680 dt,
681 );
682 states.iter().map(|s| self.decode_single(s)).collect()
683 }
684}
685
686#[derive(Debug, Clone)]
693pub struct TimeSeriesOde {
694 pub times: Vec<f64>,
696 pub observations: Vec<Vec<f64>>,
698 pub solver: NeuralOdeSolver,
700 pub learning_rate: f64,
702 pub n_iter: usize,
704 pub loss_history: Vec<f64>,
706}
707
708impl TimeSeriesOde {
709 pub fn new(
711 times: Vec<f64>,
712 observations: Vec<Vec<f64>>,
713 solver: NeuralOdeSolver,
714 learning_rate: f64,
715 n_iter: usize,
716 ) -> Self {
717 Self {
718 times,
719 observations,
720 solver,
721 learning_rate,
722 n_iter,
723 loss_history: Vec::new(),
724 }
725 }
726
727 pub fn fit(&mut self) {
733 let dt = if self.times.len() > 1 {
734 (self.times[self.times.len() - 1] - self.times[0]) / (self.times.len() as f64 * 10.0)
735 } else {
736 0.01
737 };
738
739 for _iter in 0..self.n_iter {
740 let loss = self.compute_loss(dt);
742 self.loss_history.push(loss);
743
744 let grad_scale = self.learning_rate * 0.01;
747 for b in &mut self.solver.func.bias_out {
748 *b -= grad_scale * (*b).signum();
749 }
750 }
751 }
752
753 pub fn compute_loss(&self, dt: f64) -> f64 {
755 if self.times.is_empty() || self.observations.is_empty() {
756 return 0.0;
757 }
758 let z0 = self.observations[0].clone();
759 let states = self.solver.solve_rk4_trajectory(&z0, &self.times, dt);
760 let mut mse = 0.0;
761 let mut count = 0usize;
762 for (pred, obs) in states.iter().zip(self.observations.iter()) {
763 for (p, o) in pred.iter().zip(obs.iter()) {
764 mse += (p - o).powi(2);
765 count += 1;
766 }
767 }
768 if count > 0 { mse / count as f64 } else { 0.0 }
769 }
770
771 pub fn predict(&self, t: f64) -> Vec<f64> {
775 if self.times.is_empty() || self.observations.is_empty() {
776 return vec![];
777 }
778 let z0 = self.observations[0].clone();
779 let t0 = self.times[0];
780 let dt = (t - t0).abs() / 100.0_f64.max(1.0);
781 self.solver.solve_rk4(&z0, t0, t, dt.max(1e-4))
782 }
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------- rk4_step

    #[test]
    fn test_rk4_exponential_decay() {
        // y' = -y, y(0) = 1: one step of h = 0.1 against exp(-0.1).
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let y0 = vec![1.0];
        let y1 = rk4_step(&f, 0.0, &y0, 0.1);
        let exact = (-0.1_f64).exp();
        assert!(
            (y1[0] - exact).abs() < 1e-6,
            "RK4 decay: got {}, expected {}",
            y1[0],
            exact
        );
    }

    #[test]
    fn test_rk4_harmonic_oscillator() {
        // x'' = -x written as a first-order system; x(1) = cos(1).
        let f = |_t: f64, z: &[f64]| vec![z[1], -z[0]];
        let mut z = vec![1.0, 0.0];
        let dt = 0.01;
        let steps = 100;
        for i in 0..steps {
            z = rk4_step(&f, i as f64 * dt, &z, dt);
        }
        let t = 1.0_f64;
        let exact_x = t.cos();
        assert!(
            (z[0] - exact_x).abs() < 1e-5,
            "Harmonic oscillator x: got {}",
            z[0]
        );
    }

    #[test]
    fn test_rk4_constant_ode() {
        let f = |_t: f64, _y: &[f64]| vec![2.0];
        let y = rk4_step(&f, 0.0, &[0.0], 1.0);
        assert!((y[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_rk4_zero_step() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let y0 = vec![3.0];
        let y1 = rk4_step(&f, 0.0, &y0, 0.0);
        assert!((y1[0] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_rk4_linear_ode() {
        // y' = t integrates to t^2 / 2, i.e. 2 at t = 2.
        let f = |t: f64, _y: &[f64]| vec![t];
        let mut y = vec![0.0];
        let dt = 0.01;
        for i in 0..200 {
            y = rk4_step(&f, i as f64 * dt, &y, dt);
        }
        assert!((y[0] - 2.0).abs() < 1e-6, "Linear ODE: got {}", y[0]);
    }

    #[test]
    fn test_rk4_2d_decoupled() {
        let f = |_t: f64, y: &[f64]| vec![-y[0], -2.0 * y[1]];
        let mut z = vec![1.0_f64, 1.0_f64];
        let dt = 0.01;
        for i in 0..50 {
            z = rk4_step(&f, i as f64 * dt, &z, dt);
        }
        let t = 0.5_f64;
        assert!((z[0] - (-t).exp()).abs() < 1e-5, "y1: {}", z[0]);
        assert!((z[1] - (-2.0 * t).exp()).abs() < 1e-5, "y2: {}", z[1]);
    }

    // ------------------------------------------------------------- dopri5_step

    #[test]
    fn test_dopri5_returns_three_values() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (yh, yl, err) = dopri5_step(&f, 0.0, &[1.0], 0.1, 1e-3, 1e-6);
        assert_eq!(yh.len(), 1);
        assert_eq!(yl.len(), 1);
        assert!(err.is_finite());
    }

    #[test]
    fn test_dopri5_exponential_accuracy() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (yh, _yl, _err) = dopri5_step(&f, 0.0, &[1.0], 0.1, 1e-6, 1e-9);
        let exact = (-0.1_f64).exp();
        assert!(
            (yh[0] - exact).abs() < 1e-8,
            "DOPRI5 accuracy: {}",
            (yh[0] - exact).abs()
        );
    }

    #[test]
    fn test_dopri5_zero_step_size() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (yh, yl, err) = dopri5_step(&f, 0.0, &[1.0], 0.0, 1e-3, 1e-6);
        assert!((yh[0] - 1.0).abs() < 1e-12);
        assert!((yl[0] - 1.0).abs() < 1e-12);
        assert!(err < 1e-10);
    }

    // ----------------------------------------------------------- NeuralOdeFunc

    #[test]
    fn test_neural_ode_func_forward_shape() {
        let func = NeuralOdeFunc::new(3, 8, 42);
        let z = vec![1.0, 0.0, -1.0];
        let dz = func.forward(0.0, &z);
        assert_eq!(dz.len(), 3);
    }

    #[test]
    fn test_neural_ode_func_forward_finite() {
        let func = NeuralOdeFunc::new(4, 16, 1234);
        let z = vec![0.5, -0.3, 1.2, -0.1];
        let dz = func.forward(1.0, &z);
        for &v in &dz {
            assert!(
                v.is_finite(),
                "NeuralOdeFunc output contains non-finite: {v}"
            );
        }
    }

    #[test]
    fn test_neural_ode_func_deterministic() {
        let f1 = NeuralOdeFunc::new(2, 4, 99);
        let f2 = NeuralOdeFunc::new(2, 4, 99);
        let z = vec![0.1, 0.2];
        assert_eq!(f1.forward(0.0, &z), f2.forward(0.0, &z));
    }

    #[test]
    fn test_neural_ode_func_different_seeds_differ() {
        let f1 = NeuralOdeFunc::new(2, 8, 1);
        let f2 = NeuralOdeFunc::new(2, 8, 2);
        let z = vec![1.0, 1.0];
        let d1 = f1.forward(0.0, &z);
        let d2 = f2.forward(0.0, &z);
        let diff: f64 = d1.iter().zip(d2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-10,
            "Different seeds should give different outputs"
        );
    }

    #[test]
    fn test_neural_ode_func_jvp_shape() {
        let func = NeuralOdeFunc::new(3, 6, 7);
        let z = vec![0.0, 1.0, -1.0];
        let v = vec![1.0, 0.0, 0.0];
        let jvp = func.jvp(0.5, &z, &v, 1e-5);
        assert_eq!(jvp.len(), 3);
    }

    #[test]
    fn test_neural_ode_func_params_round_trip() {
        // Writing back the flattened parameters must leave the network unchanged.
        let mut func = NeuralOdeFunc::new(3, 5, 21);
        let params = func.params_flat();
        assert_eq!(params.len(), func.n_params());
        let z = vec![0.3, -0.1, 0.2];
        let before = func.forward(0.7, &z);
        func.set_params_flat(&params);
        assert_eq!(func.forward(0.7, &z), before);
    }

    // --------------------------------------------------------- NeuralOdeSolver

    #[test]
    fn test_solver_rk4_output_shape() {
        let func = NeuralOdeFunc::new(2, 4, 0);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![1.0, 0.0];
        let z1 = solver.solve_rk4(&z0, 0.0, 1.0, 0.1);
        assert_eq!(z1.len(), 2);
    }

    #[test]
    fn test_solver_rk4_zero_integration() {
        let func = NeuralOdeFunc::new(2, 4, 5);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![1.0, 2.0];
        let z1 = solver.solve_rk4(&z0, 0.0, 0.0, 0.1);
        for (a, b) in z0.iter().zip(z1.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    #[test]
    fn test_solver_rk4_finite_output() {
        let func = NeuralOdeFunc::new(3, 8, 100);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![0.1, -0.2, 0.3];
        let z1 = solver.solve_rk4(&z0, 0.0, 0.5, 0.05);
        for &v in &z1 {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_solver_trajectory_length() {
        let func = NeuralOdeFunc::new(2, 4, 3);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let z0 = vec![1.0, 0.0];
        let ts = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let traj = solver.solve_rk4_trajectory(&z0, &ts, 0.05);
        assert_eq!(traj.len(), ts.len());
    }

    #[test]
    fn test_solver_dopri5_output_shape() {
        let func = NeuralOdeFunc::new(2, 4, 42);
        let solver = NeuralOdeSolver::new(func, 1e-4, 1e-7);
        let z0 = vec![1.0, 0.5];
        let z1 = solver.solve_dopri5(&z0, 0.0, 1.0, 0.1);
        assert_eq!(z1.len(), 2);
    }

    #[test]
    fn test_solver_dopri5_finite_output() {
        let func = NeuralOdeFunc::new(3, 6, 77);
        let solver = NeuralOdeSolver::new(func, 1e-4, 1e-7);
        let z0 = vec![0.0, 0.5, 1.0];
        let z1 = solver.solve_dopri5(&z0, 0.0, 0.5, 0.1);
        for &v in &z1 {
            assert!(v.is_finite(), "DOPRI5 produced non-finite: {v}");
        }
    }

    // ----------------------------------------------------------- AdjointMethod

    #[test]
    fn test_adjoint_backward_shape() {
        let adj = AdjointMethod::new(4);
        let loss_grad = vec![1.0, -1.0, 0.5, -0.5];
        let grad = adj.backward(&loss_grad);
        assert_eq!(grad.len(), 4);
    }

    #[test]
    fn test_adjoint_backward_negation() {
        let adj = AdjointMethod::new(3);
        let loss_grad = vec![2.0, -3.0, 1.0];
        let grad = adj.backward(&loss_grad);
        assert_eq!(grad, vec![-2.0, 3.0, -1.0]);
    }

    #[test]
    fn test_adjoint_run_returns_correct_shapes() {
        let func = NeuralOdeFunc::new(2, 4, 11);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let mut adj = AdjointMethod::new(2);
        let z_final = vec![0.5, -0.5];
        let loss_grad = vec![1.0, 0.0];
        let (grad_z0, grad_params) = adj.run(&solver, &z_final, &loss_grad, 0.0, 1.0, 0.1);
        assert_eq!(grad_z0.len(), 2);
        assert!(!grad_params.is_empty());
    }

    #[test]
    fn test_adjoint_run_finite() {
        let func = NeuralOdeFunc::new(2, 4, 22);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let mut adj = AdjointMethod::new(2);
        let z_final = vec![1.0, 1.0];
        let loss_grad = vec![0.1, -0.1];
        let (g, _) = adj.run(&solver, &z_final, &loss_grad, 0.0, 1.0, 0.1);
        for &v in &g {
            assert!(v.is_finite());
        }
    }

    // --------------------------------------------------------------- LatentOde

    #[test]
    fn test_latent_ode_encode_shape() {
        let model = LatentOde::new(4, 2, 8, 55);
        let obs = vec![vec![1.0, 0.0, -1.0, 0.5], vec![0.5, 0.1, -0.5, 0.3]];
        let z = model.encode(&obs);
        assert_eq!(z.len(), 2);
    }

    #[test]
    fn test_latent_ode_encode_empty() {
        let model = LatentOde::new(3, 2, 4, 1);
        let z = model.encode(&[]);
        assert_eq!(z.len(), 2);
        assert!(z.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_latent_ode_decode_single_shape() {
        let model = LatentOde::new(4, 2, 6, 88);
        let z = vec![0.5, -0.3];
        let obs = model.decode_single(&z);
        assert_eq!(obs.len(), 4);
    }

    #[test]
    fn test_latent_ode_decode_trajectory_length() {
        // `decode` also returns the decoded state at t0, hence the extra entry.
        let model = LatentOde::new(3, 2, 4, 33);
        let z = vec![0.1, 0.2];
        let ts = vec![0.1, 0.2, 0.5, 1.0];
        let preds = model.decode(&z, 0.0, &ts, 0.05);
        assert_eq!(preds.len(), ts.len() + 1);
    }

    #[test]
    fn test_latent_ode_encode_finite() {
        let model = LatentOde::new(3, 4, 8, 999);
        let obs: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 * 0.1; 3]).collect();
        let z = model.encode(&obs);
        assert!(
            z.iter().all(|v| v.is_finite()),
            "Encoded latent contains non-finite"
        );
    }

    #[test]
    fn test_latent_ode_round_trip_shape() {
        let model = LatentOde::new(2, 2, 4, 77);
        let obs = vec![vec![1.0, 0.0], vec![0.8, 0.1]];
        let z = model.encode(&obs);
        let recon = model.decode_single(&z);
        assert_eq!(recon.len(), 2);
    }

    // ----------------------------------------------------------- TimeSeriesOde

    #[test]
    fn test_time_series_ode_predict_shape() {
        let func = NeuralOdeFunc::new(2, 4, 13);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.5, 1.0];
        let obs = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]];
        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
        let pred = ts.predict(1.5);
        assert_eq!(pred.len(), 2);
    }

    #[test]
    fn test_time_series_ode_loss_nonnegative() {
        let func = NeuralOdeFunc::new(2, 4, 14);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.5, 1.0];
        let obs = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.8, 0.2]];
        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
        assert!(ts.compute_loss(0.05) >= 0.0);
    }

    #[test]
    fn test_time_series_ode_fit_records_loss() {
        let func = NeuralOdeFunc::new(1, 4, 15);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.1, 0.2, 0.3];
        let obs: Vec<Vec<f64>> = (0..4).map(|i| vec![(-(i as f64) * 0.1).exp()]).collect();
        let mut ts = TimeSeriesOde::new(times, obs, solver, 0.001, 5);
        ts.fit();
        assert_eq!(ts.loss_history.len(), 5);
    }

    #[test]
    fn test_time_series_ode_predict_finite() {
        let func = NeuralOdeFunc::new(2, 4, 16);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times = vec![0.0, 0.5];
        let obs = vec![vec![1.0, 0.0], vec![0.9, -0.1]];
        let ts = TimeSeriesOde::new(times, obs, solver, 0.01, 0);
        let pred = ts.predict(0.3);
        assert!(pred.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_time_series_ode_empty() {
        let func = NeuralOdeFunc::new(2, 4, 17);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let ts = TimeSeriesOde::new(vec![], vec![], solver, 0.01, 0);
        let pred = ts.predict(1.0);
        assert!(pred.is_empty());
        assert_eq!(ts.compute_loss(0.1), 0.0);
    }

    // ------------------------------------------------------ accuracy / physics

    #[test]
    fn test_rk4_logistic_growth() {
        // y' = y (1 - y), y(0) = 0.1 has solution 1 / (1 + 9 e^{-t}).
        let f = |_t: f64, y: &[f64]| vec![y[0] * (1.0 - y[0])];
        let mut y = vec![0.1];
        let dt = 0.01;
        let steps = 200;
        for i in 0..steps {
            y = rk4_step(&f, i as f64 * dt, &y, dt);
        }
        let t = 2.0_f64;
        let exact = 1.0 / (1.0 + 9.0 * (-t).exp());
        assert!(
            (y[0] - exact).abs() < 1e-5,
            "Logistic growth: got {}, expected {}",
            y[0],
            exact
        );
    }

    #[test]
    fn test_rk4_accuracy_order() {
        // RK4 has global error O(h^4): halving h should shrink the error by about 16.
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let exact = (-1.0_f64).exp();

        let y_h1 = {
            let mut y = vec![1.0];
            for i in 0..10 {
                y = rk4_step(&f, i as f64 * 0.1, &y, 0.1);
            }
            y[0]
        };
        let y_h2 = {
            let mut y = vec![1.0];
            for i in 0..20 {
                y = rk4_step(&f, i as f64 * 0.05, &y, 0.05);
            }
            y[0]
        };
        let err1 = (y_h1 - exact).abs();
        let err2 = (y_h2 - exact).abs();
        assert!(
            err2 < err1,
            "Smaller step should give smaller error: {} vs {}",
            err2,
            err1
        );
        let ratio = err1 / err2.max(f64::MIN_POSITIVE);
        assert!(
            ratio > 10.0,
            "Expected ~16x error reduction for a fourth-order method; got ratio={ratio:.2}"
        );
    }

    #[test]
    fn test_rk4_system_energy_conservation() {
        let f = |_t: f64, z: &[f64]| vec![z[1], -z[0]];
        let mut z = vec![1.0, 0.0];
        let dt = 0.001;
        let steps = 1000;
        for i in 0..steps {
            z = rk4_step(&f, i as f64 * dt, &z, dt);
        }
        let energy = 0.5 * (z[0].powi(2) + z[1].powi(2));
        assert!(
            (energy - 0.5).abs() < 1e-4,
            "Energy drift: {}",
            energy - 0.5
        );
    }

    #[test]
    fn test_neural_ode_func_batch_consistency() {
        let func = NeuralOdeFunc::new(3, 8, 42);
        let z = vec![0.1, -0.2, 0.3];
        let d1 = func.forward(0.5, &z);
        let d2 = func.forward(0.5, &z);
        assert_eq!(d1, d2, "forward must be deterministic");
    }

    #[test]
    fn test_time_series_ode_fit_loss_finite() {
        let func = NeuralOdeFunc::new(1, 4, 18);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let times: Vec<f64> = (0..5).map(|i| i as f64 * 0.2).collect();
        let obs: Vec<Vec<f64>> = times.iter().map(|&t: &f64| vec![(-t).exp()]).collect();
        let mut ts = TimeSeriesOde::new(times, obs, solver, 0.001, 3);
        ts.fit();
        for &l in &ts.loss_history {
            assert!(l.is_finite(), "Loss is non-finite: {l}");
        }
    }

    #[test]
    fn test_rk4_step_multidim() {
        let f = |_t: f64, y: &[f64]| -> Vec<f64> {
            (0..y.len()).map(|i| -(i as f64 + 1.0) * y[i]).collect()
        };
        let y0: Vec<f64> = vec![1.0; 5];
        let mut y = y0.clone();
        let dt = 0.01;
        for k in 0..10 {
            y = rk4_step(&f, k as f64 * dt, &y, dt);
        }
        for (i, &yi) in y.iter().enumerate() {
            let exact = (-(i as f64 + 1.0) * 0.1).exp();
            assert!(
                (yi - exact).abs() < 1e-5,
                "dim {i}: got {yi}, expected {exact}"
            );
        }
    }

    #[test]
    fn test_dopri5_error_estimate_order() {
        // y' = y. The one-step (local) error of the fifth-order solution scales like h^6, so
        // halving h should shrink it by roughly 2^6 = 64.
        let f = |_t: f64, y: &[f64]| vec![y[0]];
        let rtol = 1e-12;
        let atol = 1e-12;
        let y0 = vec![1.0_f64];

        let (y_big, _, _) = dopri5_step(&f, 0.0, &y0, 0.2, rtol, atol);
        let (y_small, _, _) = dopri5_step(&f, 0.0, &y0, 0.1, rtol, atol);
        let err_big = (y_big[0] - 0.2_f64.exp()).abs();
        let err_small = (y_small[0] - 0.1_f64.exp()).abs();
        let ratio = err_big / err_small.max(f64::MIN_POSITIVE);
        assert!(
            ratio > 10.0,
            "Expected ~64× local error reduction when halving step; got ratio={ratio:.2}"
        );
    }

    #[test]
    fn test_dopri5_error_norm_small_step() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let (_, _, err) = dopri5_step(&f, 0.0, &[1.0], 0.01, 1e-6, 1e-8);
        assert!(err < 1.0, "error norm should be < 1 for h=0.01: {err}");
    }

    #[test]
    fn test_adjoint_param_gradient_nonzero_and_finite() {
        let func = NeuralOdeFunc::new(2, 4, 99);
        let solver = NeuralOdeSolver::new(func, 1e-3, 1e-6);
        let mut adj = AdjointMethod::new(2);
        let z_final = vec![0.5, -0.3];
        let loss_grad = vec![1.0, 0.0];
        let (_, grad_params) = adj.run(&solver, &z_final, &loss_grad, 0.0, 0.5, 0.1);
        assert_eq!(grad_params.len(), solver.func.n_params());
        assert!(
            grad_params.iter().all(|v| v.is_finite()),
            "some parameter gradients are non-finite"
        );
        assert!(
            grad_params.iter().any(|v| v.abs() > 1e-15),
            "all parameter gradients are zero"
        );
    }
}