use crate::tensor::Tensor;
use cjc_repro::kahan_sum_f64;

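/// Single explicit Euler step: `y_{n+1} = y_n + h * dy/dt`, applied element-wise.
///
/// Illustrative sketch only; constructor usage mirrors the tests at the bottom of
/// this file, and exact import paths depend on the crate layout:
///
/// ```ignore
/// let y = Tensor::from_vec_unchecked(vec![1.0], &[1]);
/// let dydt = Tensor::from_vec_unchecked(vec![2.0], &[1]);
/// let y1 = ode_step_euler(&y, &dydt, 0.1); // [1.0 + 0.1 * 2.0] = [1.2]
/// ```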
pub fn ode_step_euler(y: &Tensor, dydt: &Tensor, h: f64) -> Tensor {
    let y_data = y.to_vec();
    let dy_data = dydt.to_vec();
    assert_eq!(y_data.len(), dy_data.len(), "ode_step_euler: y and dydt must have same length");

    let result: Vec<f64> = y_data.iter().zip(dy_data.iter())
        .map(|(&yi, &dyi)| yi + h * dyi)
        .collect();
    Tensor::from_vec_unchecked(result, y.shape())
}

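/// Combine four precomputed RK4 slopes into one step:
/// `y_{n+1} = y_n + (h / 6) * (k1 + 2*k2 + 2*k3 + k4)`.
///
/// The weighted sum is accumulated with `kahan_sum_f64` so the result is
/// deterministic and less sensitive to rounding than a naive left-to-right sum.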
pub fn ode_step_rk4(y: &Tensor, k1: &Tensor, k2: &Tensor, k3: &Tensor, k4: &Tensor, h: f64) -> Tensor {
    let y_data = y.to_vec();
    let k1_data = k1.to_vec();
    let k2_data = k2.to_vec();
    let k3_data = k3.to_vec();
    let k4_data = k4.to_vec();
    let n = y_data.len();
    assert_eq!(k1_data.len(), n);
    assert_eq!(k2_data.len(), n);
    assert_eq!(k3_data.len(), n);
    assert_eq!(k4_data.len(), n);

    let h6 = h / 6.0;
    let result: Vec<f64> = (0..n)
        .map(|i| {
            let terms = [
                k1_data[i],
                2.0 * k2_data[i],
                2.0 * k3_data[i],
                k4_data[i],
            ];
            y_data[i] + h6 * kahan_sum_f64(&terms)
        })
        .collect();
    Tensor::from_vec_unchecked(result, y.shape())
}

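/// Second-order central-difference Laplacian on a 1-D grid with spacing `dx`:
/// `lap[i] = (u[i-1] - 2*u[i] + u[i+1]) / dx^2` for interior points.
/// The two boundary entries are left at zero.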
pub fn pde_laplacian_1d(u: &Tensor, dx: f64) -> Tensor {
    let data = u.to_vec();
    let n = data.len();
    let dx2_inv = 1.0 / (dx * dx);
    let mut lap = vec![0.0_f64; n];

    // Interior points only; boundary values stay zero. Guard against n < 2,
    // where `n - 1` would underflow.
    if n >= 2 {
        for i in 1..n - 1 {
            lap[i] = (data[i - 1] - 2.0 * data[i] + data[i + 1]) * dx2_inv;
        }
    }

    Tensor::from_vec_unchecked(lap, u.shape())
}

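/// One forward-Euler step of the 1-D heat equation `du/dt = alpha * d2u/dx2`,
/// i.e. `u_{n+1} = u_n + dt * alpha * laplacian(u_n)`.
///
/// Note: this explicit scheme is only stable for `alpha * dt / dx^2 <= 0.5`;
/// the caller is responsible for choosing `dt` accordingly.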
pub fn pde_step_diffusion(u: &Tensor, alpha: f64, dt: f64, dx: f64) -> Tensor {
    let lap = pde_laplacian_1d(u, dx);
    let u_data = u.to_vec();
    let lap_data = lap.to_vec();
    let result: Vec<f64> = u_data.iter().zip(lap_data.iter())
        .map(|(&ui, &li)| ui + dt * alpha * li)
        .collect();
    Tensor::from_vec_unchecked(result, u.shape())
}

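/// A small symbolic expression tree over `f64` constants and named variables,
/// supporting the operations needed for closed-form differentiation.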
#[derive(Debug, Clone, PartialEq)]
pub enum SymExpr {
    Const(f64),
    Var(String),
    Add(Box<SymExpr>, Box<SymExpr>),
    Mul(Box<SymExpr>, Box<SymExpr>),
    Pow(Box<SymExpr>, f64),
    Sin(Box<SymExpr>),
    Cos(Box<SymExpr>),
    Exp(Box<SymExpr>),
    Ln(Box<SymExpr>),
    Neg(Box<SymExpr>),
}

impl SymExpr {
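    /// Symbolic derivative of `self` with respect to `var`, built from the
    /// standard rules (sum, product, power, and chain rule for sin/cos/exp/ln).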
    pub fn differentiate(&self, var: &str) -> SymExpr {
        match self {
            SymExpr::Const(_) => SymExpr::Const(0.0),
            SymExpr::Var(name) => {
                if name == var {
                    SymExpr::Const(1.0)
                } else {
                    SymExpr::Const(0.0)
                }
            }
            SymExpr::Add(a, b) => SymExpr::Add(
                Box::new(a.differentiate(var)),
                Box::new(b.differentiate(var)),
            ),
            SymExpr::Mul(a, b) => {
                SymExpr::Add(
                    Box::new(SymExpr::Mul(
                        Box::new(a.differentiate(var)),
                        b.clone(),
                    )),
                    Box::new(SymExpr::Mul(
                        a.clone(),
                        Box::new(b.differentiate(var)),
                    )),
                )
            }
            SymExpr::Pow(base, exp) => {
                SymExpr::Mul(
                    Box::new(SymExpr::Mul(
                        Box::new(SymExpr::Const(*exp)),
                        Box::new(SymExpr::Pow(base.clone(), exp - 1.0)),
                    )),
                    Box::new(base.differentiate(var)),
                )
            }
            SymExpr::Sin(inner) => {
                SymExpr::Mul(
                    Box::new(SymExpr::Cos(inner.clone())),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Cos(inner) => {
                SymExpr::Mul(
                    Box::new(SymExpr::Neg(Box::new(SymExpr::Sin(inner.clone())))),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Exp(inner) => {
                SymExpr::Mul(
                    Box::new(SymExpr::Exp(inner.clone())),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Ln(inner) => {
                SymExpr::Mul(
                    Box::new(SymExpr::Pow(inner.clone(), -1.0)),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Neg(inner) => {
                SymExpr::Neg(Box::new(inner.differentiate(var)))
            }
        }
    }

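    /// Evaluate the expression numerically; unbound variables default to 0.0.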
    pub fn eval(&self, bindings: &std::collections::BTreeMap<String, f64>) -> f64 {
        match self {
            SymExpr::Const(c) => *c,
            SymExpr::Var(name) => *bindings.get(name).unwrap_or(&0.0),
            SymExpr::Add(a, b) => a.eval(bindings) + b.eval(bindings),
            SymExpr::Mul(a, b) => a.eval(bindings) * b.eval(bindings),
            SymExpr::Pow(base, exp) => base.eval(bindings).powf(*exp),
            SymExpr::Sin(inner) => inner.eval(bindings).sin(),
            SymExpr::Cos(inner) => inner.eval(bindings).cos(),
            SymExpr::Exp(inner) => inner.eval(bindings).exp(),
            SymExpr::Ln(inner) => inner.eval(bindings).ln(),
            SymExpr::Neg(inner) => -inner.eval(bindings),
        }
    }
}

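// Element-wise helpers used by the solvers below; each copies through `to_vec`
// and allocates a fresh Vec rather than mutating in place.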
#[allow(dead_code)]
fn tensor_add(a: &Tensor, b: &Tensor) -> Tensor {
    let a_data = a.to_vec();
    let b_data = b.to_vec();
    debug_assert_eq!(a_data.len(), b_data.len());
    let result: Vec<f64> = a_data.iter().zip(b_data.iter()).map(|(&ai, &bi)| ai + bi).collect();
    Tensor::from_vec_unchecked(result, a.shape())
}

fn tensor_scale(a: &Tensor, scalar: f64) -> Tensor {
    let a_data = a.to_vec();
    let result: Vec<f64> = a_data.iter().map(|&ai| scalar * ai).collect();
    Tensor::from_vec_unchecked(result, a.shape())
}

fn tensor_add_scaled(a: &Tensor, b: &Tensor, scalar: f64) -> Tensor {
    let a_data = a.to_vec();
    let b_data = b.to_vec();
    debug_assert_eq!(a_data.len(), b_data.len());
    let result: Vec<f64> = a_data.iter().zip(b_data.iter()).map(|(&ai, &bi)| ai + scalar * bi).collect();
    Tensor::from_vec_unchecked(result, a.shape())
}

fn tensor_norm(a: &Tensor) -> f64 {
    let data = a.to_vec();
    let terms: Vec<f64> = data.iter().map(|&x| x * x).collect();
    kahan_sum_f64(&terms).sqrt()
}

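/// Integrate `dy/dt = f(t, y)` from `t_span.0` to `t_span.1` with the classical
/// fixed-step fourth-order Runge-Kutta method, using `n_steps` equal steps.
/// Returns `n_steps + 1` time points and states, including the initial condition.
///
/// Illustrative sketch only (mirrors the exponential-decay test below):
///
/// ```ignore
/// // dy/dt = -y, y(0) = 1  =>  y(1) ~= exp(-1)
/// let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
/// let (ts, ys) = ode_solve_rk4(|_t, y| tensor_scale(y, -1.0), &y0, (0.0, 1.0), 100);
/// assert!((ys.last().unwrap().to_vec()[0] - (-1.0_f64).exp()).abs() < 1e-8);
/// ```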
pub fn ode_solve_rk4<F>(
    mut f: F,
    y0: &Tensor,
    t_span: (f64, f64),
    n_steps: usize,
) -> (Vec<f64>, Vec<Tensor>)
where
    F: FnMut(f64, &Tensor) -> Tensor,
{
    assert!(n_steps > 0, "ode_solve_rk4: n_steps must be > 0");
    let (t0, t1) = t_span;
    let h = (t1 - t0) / n_steps as f64;

    let mut ts = Vec::with_capacity(n_steps + 1);
    let mut ys = Vec::with_capacity(n_steps + 1);

    ts.push(t0);
    ys.push(y0.clone());

    let mut t = t0;
    let mut y = y0.clone();

    for _ in 0..n_steps {
        let k1 = f(t, &y);
        let y2 = tensor_add_scaled(&y, &k1, h * 0.5);
        let k2 = f(t + h * 0.5, &y2);
        let y3 = tensor_add_scaled(&y, &k2, h * 0.5);
        let k3 = f(t + h * 0.5, &y3);
        let y4 = tensor_add_scaled(&y, &k3, h);
        let k4 = f(t + h, &y4);

        y = ode_step_rk4(&y, &k1, &k2, &k3, &k4, h);
        t += h;

        ts.push(t);
        ys.push(y.clone());
    }

    (ts, ys)
}

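/// Butcher tableau coefficients for the Dormand-Prince 5(4) embedded pair:
/// `C*` are stage times, `A*` the stage couplings, `B*` the 5th-order solution
/// weights, and `E*` the weights of the embedded error estimate. The `B2` and
/// `E2` entries are zero and therefore omitted.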
mod dp5 {
    pub const C2: f64 = 1.0 / 5.0;
    pub const C3: f64 = 3.0 / 10.0;
    pub const C4: f64 = 4.0 / 5.0;
    pub const C5: f64 = 8.0 / 9.0;
    pub const A21: f64 = 1.0 / 5.0;
    pub const A31: f64 = 3.0 / 40.0;
    pub const A32: f64 = 9.0 / 40.0;
    pub const A41: f64 = 44.0 / 45.0;
    pub const A42: f64 = -56.0 / 15.0;
    pub const A43: f64 = 32.0 / 9.0;
    pub const A51: f64 = 19372.0 / 6561.0;
    pub const A52: f64 = -25360.0 / 2187.0;
    pub const A53: f64 = 64448.0 / 6561.0;
    pub const A54: f64 = -212.0 / 729.0;
    pub const A61: f64 = 9017.0 / 3168.0;
    pub const A62: f64 = -355.0 / 33.0;
    pub const A63: f64 = 46732.0 / 5247.0;
    pub const A64: f64 = 49.0 / 176.0;
    pub const A65: f64 = -5103.0 / 18656.0;

    pub const B1: f64 = 35.0 / 384.0;
    pub const B3: f64 = 500.0 / 1113.0;
    pub const B4: f64 = 125.0 / 192.0;
    pub const B5: f64 = -2187.0 / 6784.0;
    pub const B6: f64 = 11.0 / 84.0;

    pub const E1: f64 = 71.0 / 57600.0;
    pub const E3: f64 = -71.0 / 16695.0;
    pub const E4: f64 = 71.0 / 1920.0;
    pub const E5: f64 = -17253.0 / 339200.0;
    pub const E6: f64 = 22.0 / 525.0;
    pub const E7: f64 = -1.0 / 40.0;
}

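/// Adaptive Dormand-Prince 5(4) integration of `dy/dt = f(t, y)` over `t_span`.
///
/// Each step computes a 5th-order candidate plus an embedded 4th-order error
/// estimate; the step is accepted when the scaled RMS error norm is <= 1, and
/// the step size is then grown or shrunk with the usual safety-factor rule.
/// `rtol` and `atol` are the relative and absolute tolerances. Returns the
/// accepted time points and states, including the initial condition.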
pub fn ode_solve_rk45<F>(
    mut f: F,
    y0: &Tensor,
    t_span: (f64, f64),
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<Tensor>)
where
    F: FnMut(f64, &Tensor) -> Tensor,
{
    let (t0, t1) = t_span;
    assert!(t1 > t0, "ode_solve_rk45: t1 must be > t0");

    let mut ts = Vec::new();
    let mut ys = Vec::new();

    ts.push(t0);
    ys.push(y0.clone());

    let n = y0.to_vec().len();

    // Crude initial step-size guess based on the magnitude of f at t0.
    let f0 = f(t0, y0);
    let f0_norm = tensor_norm(&f0).max(1e-300);
    let mut h = (0.01 * (t1 - t0)).min(0.1 / f0_norm);
    h = h.max(1e-12);

    let mut t = t0;
    let mut y = y0.clone();
    let safety = 0.9_f64;
    let max_factor = 10.0_f64;
    let min_factor = 0.2_f64;
    let max_steps = 1_000_000_usize;
    let mut step_count = 0;

    while t < t1 && step_count < max_steps {
        // Clip the last step so the solver lands exactly on t1.
        if t + h > t1 {
            h = t1 - t;
        }

        // Stages 1-6 of the Dormand-Prince tableau.
        let k1 = f(t, &y);
        let y2 = tensor_add_scaled(&y, &k1, h * dp5::A21);
        let k2 = f(t + dp5::C2 * h, &y2);

        let k1d = k1.to_vec();
        let k2d = k2.to_vec();
        let mut y3_data = y.to_vec();
        for i in 0..n {
            y3_data[i] += h * (dp5::A31 * k1d[i] + dp5::A32 * k2d[i]);
        }
        let y3 = Tensor::from_vec_unchecked(y3_data, y.shape());
        let k3 = f(t + dp5::C3 * h, &y3);
        let k3d = k3.to_vec();

        let mut y4_data = y.to_vec();
        for i in 0..n {
            y4_data[i] += h * (dp5::A41 * k1d[i] + dp5::A42 * k2d[i] + dp5::A43 * k3d[i]);
        }
        let y4 = Tensor::from_vec_unchecked(y4_data, y.shape());
        let k4 = f(t + dp5::C4 * h, &y4);
        let k4d = k4.to_vec();

        let mut y5_data = y.to_vec();
        for i in 0..n {
            y5_data[i] += h * (dp5::A51 * k1d[i] + dp5::A52 * k2d[i] + dp5::A53 * k3d[i] + dp5::A54 * k4d[i]);
        }
        let y5 = Tensor::from_vec_unchecked(y5_data, y.shape());
        let k5 = f(t + dp5::C5 * h, &y5);
        let k5d = k5.to_vec();

        let mut y6_data = y.to_vec();
        for i in 0..n {
            y6_data[i] += h * (dp5::A61 * k1d[i] + dp5::A62 * k2d[i] + dp5::A63 * k3d[i] + dp5::A64 * k4d[i] + dp5::A65 * k5d[i]);
        }
        let y6 = Tensor::from_vec_unchecked(y6_data, y.shape());
        let k6 = f(t + h, &y6);
        let k6d = k6.to_vec();

        // 5th-order candidate solution (B2 is zero and omitted).
        let y_data = y.to_vec();
        let mut y5th_data = vec![0.0_f64; n];
        for i in 0..n {
            let terms = [
                dp5::B1 * k1d[i],
                dp5::B3 * k3d[i],
                dp5::B4 * k4d[i],
                dp5::B5 * k5d[i],
                dp5::B6 * k6d[i],
            ];
            y5th_data[i] = y_data[i] + h * kahan_sum_f64(&terms);
        }
        let y5th = Tensor::from_vec_unchecked(y5th_data.clone(), y.shape());

        // Seventh stage, evaluated at the candidate solution and used only in
        // the embedded error estimate here.
        let k7 = f(t + h, &y5th);
        let k7d = k7.to_vec();

        // Scaled RMS norm of the embedded error estimate.
        let mut err_sq_acc = 0.0_f64;
        for i in 0..n {
            let e_terms = [
                dp5::E1 * k1d[i],
                dp5::E3 * k3d[i],
                dp5::E4 * k4d[i],
                dp5::E5 * k5d[i],
                dp5::E6 * k6d[i],
                dp5::E7 * k7d[i],
            ];
            let e_i = h * kahan_sum_f64(&e_terms);
            let sc = atol + rtol * y5th_data[i].abs().max(y_data[i].abs());
            err_sq_acc += (e_i / sc) * (e_i / sc);
        }
        let err_norm = (err_sq_acc / n as f64).sqrt();

        if err_norm <= 1.0 {
            // Accept the step, then grow h (clamped to [min_factor, max_factor]).
            t += h;
            y = y5th;
            ts.push(t);
            ys.push(y.clone());
            step_count += 1;

            let factor = (safety * err_norm.powf(-0.2)).min(max_factor).max(min_factor);
            h = (h * factor).min(t1 - t);
            if h < 1e-14 {
                break;
            }
        } else {
            // Reject the step and retry with a smaller h.
            let factor = (safety * err_norm.powf(-0.25)).max(min_factor);
            h *= factor;
            if h < 1e-14 {
                break;
            }
        }
    }

    (ts, ys)
}

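/// Continuous-adjoint backward pass: starting from the final state `y_final`
/// at `t_span.1`, the state is integrated backwards in time (state rate `-f`)
/// together with an adjoint variable initialised to zero, using `n_steps`
/// fixed RK4 steps. `f` is the forward dynamics; `grad_f` maps `(t, y, a)` to
/// the backward-time adjoint rate and a parameter-gradient contribution (the
/// second value is ignored here). Returns the reconstructed initial state and
/// the adjoint at `t_span.0`.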
pub fn adjoint_solve<F, G>(
    mut f: F,
    mut grad_f: G,
    y_final: &Tensor,
    t_span: (f64, f64),
    n_steps: usize,
) -> (Tensor, Tensor)
where
    F: FnMut(f64, &Tensor) -> Tensor,
    G: FnMut(f64, &Tensor, &Tensor) -> (Tensor, Tensor),
{
    assert!(n_steps > 0, "adjoint_solve: n_steps must be > 0");
    let (t0, t1) = t_span;
    let h = (t1 - t0) / n_steps as f64;

    let n = y_final.to_vec().len();

    let a0 = Tensor::from_vec_unchecked(vec![0.0_f64; n], y_final.shape());

    let mut t = t1;
    let mut y = y_final.clone();
    let mut a = a0;

    for _ in 0..n_steps {
        let t_prev = t - h;

        // RK4 stages for the coupled (state, adjoint) system in reverse time:
        // the state rate is -f, the adjoint rate comes from grad_f.
        let ky1 = tensor_scale(&f(t, &y), -1.0);
        let (ka1, _) = grad_f(t, &y, &a);

        let y2 = tensor_add_scaled(&y, &ky1, h * 0.5);
        let a2 = tensor_add_scaled(&a, &ka1, h * 0.5);
        let ky2 = tensor_scale(&f(t - h * 0.5, &y2), -1.0);
        let (ka2, _) = grad_f(t - h * 0.5, &y2, &a2);

        let y3 = tensor_add_scaled(&y, &ky2, h * 0.5);
        let a3 = tensor_add_scaled(&a, &ka2, h * 0.5);
        let ky3 = tensor_scale(&f(t - h * 0.5, &y3), -1.0);
        let (ka3, _) = grad_f(t - h * 0.5, &y3, &a3);

        let y4 = tensor_add_scaled(&y, &ky3, h);
        let a4 = tensor_add_scaled(&a, &ka3, h);
        let ky4 = tensor_scale(&f(t_prev, &y4), -1.0);
        let (ka4, _) = grad_f(t_prev, &y4, &a4);

        y = ode_step_rk4(&y, &ky1, &ky2, &ky3, &ky4, h);
        a = ode_step_rk4(&a, &ka1, &ka2, &ka3, &ka4, h);
        t = t_prev;
    }

    (y, a)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn test_euler_step() {
        let y = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
        let dydt = Tensor::from_vec_unchecked(vec![0.0, 1.0], &[2]);
        let y1 = ode_step_euler(&y, &dydt, 0.1);
        let result = y1.to_vec();
        assert!((result[0] - 1.0).abs() < 1e-15);
        assert!((result[1] - 0.1).abs() < 1e-15);
    }

    #[test]
    fn test_rk4_step_constant() {
        let y = Tensor::from_vec_unchecked(vec![1.0], &[1]);
        let k = Tensor::from_vec_unchecked(vec![2.0], &[1]);
        let y1 = ode_step_rk4(&y, &k, &k, &k, &k, 0.1);
        assert!((y1.to_vec()[0] - 1.2).abs() < 1e-14);
    }

    #[test]
    fn test_laplacian_1d() {
        // u = x^2 on a unit grid has a constant second derivative of 2 at interior points.
        let u = Tensor::from_vec_unchecked(vec![0.0, 1.0, 4.0, 9.0, 16.0], &[5]);
        let lap = pde_laplacian_1d(&u, 1.0);
        let data = lap.to_vec();
        assert!((data[0] - 0.0).abs() < 1e-14);
        assert!((data[1] - 2.0).abs() < 1e-14);
        assert!((data[2] - 2.0).abs() < 1e-14);
        assert!((data[3] - 2.0).abs() < 1e-14);
        assert!((data[4] - 0.0).abs() < 1e-14);
    }

    #[test]
    fn test_symbolic_diff_polynomial() {
        // d/dx x^3 = 3x^2, which is 12 at x = 2.
        let expr = SymExpr::Pow(Box::new(SymExpr::Var("x".into())), 3.0);
        let deriv = expr.differentiate("x");

        let mut bindings = BTreeMap::new();
        bindings.insert("x".into(), 2.0);

        let val = deriv.eval(&bindings);
        assert!((val - 12.0).abs() < 1e-12);
    }

    #[test]
    fn test_symbolic_diff_sin() {
        // d/dx sin(x) = cos(x), which is 1 at x = 0.
        let expr = SymExpr::Sin(Box::new(SymExpr::Var("x".into())));
        let deriv = expr.differentiate("x");

        let mut bindings = BTreeMap::new();
        bindings.insert("x".into(), 0.0);

        let val = deriv.eval(&bindings);
        assert!((val - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_rk4_exponential_decay() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            tensor_scale(y, -1.0)
        };
        let (ts, ys) = ode_solve_rk4(f, &y0, (0.0, 1.0), 100);

        assert_eq!(ts.len(), 101);
        assert_eq!(ys.len(), 101);
        assert!((ts[0] - 0.0).abs() < 1e-15);
        assert!((ts[100] - 1.0).abs() < 1e-12);

        let y_final = ys[100].to_vec()[0];
        let exact = (-1.0_f64).exp();
        assert!(
            (y_final - exact).abs() < 1e-8,
            "RK4 decay: got {}, expected {}",
            y_final, exact
        );
    }

    #[test]
    fn test_rk4_harmonic_oscillator() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            let d = y.to_vec();
            Tensor::from_vec_unchecked(vec![d[1], -d[0]], &[2])
        };
        let t_end = std::f64::consts::PI / 2.0;
        let (ts, ys) = ode_solve_rk4(f, &y0, (0.0, t_end), 1000);

        let y_end = ys.last().unwrap().to_vec();
        assert!(
            y_end[0].abs() < 1e-7,
            "harmonic y(pi/2) should be ~0, got {}",
            y_end[0]
        );
        assert!(
            (y_end[1] - (-1.0)).abs() < 1e-7,
            "harmonic v(pi/2) should be ~-1, got {}",
            y_end[1]
        );
        let _ = ts;
    }

    #[test]
    fn test_rk45_exponential_decay() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            tensor_scale(y, -1.0)
        };
        let (ts, ys) = ode_solve_rk45(f, &y0, (0.0, 1.0), 1e-8, 1e-10);

        assert!(!ts.is_empty(), "RK45 should produce at least one step");
        let y_final = ys.last().unwrap().to_vec()[0];
        let t_final = *ts.last().unwrap();
        let exact = (-t_final).exp();
        assert!(
            (y_final - exact).abs() < 1e-6,
            "RK45 decay: got {} at t={}, expected {}",
            y_final, t_final, exact
        );
    }

    #[test]
    fn test_rk45_fewer_steps_than_rk4_fixed() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);

        let f_adaptive = |_t: f64, y: &Tensor| -> Tensor { tensor_scale(y, -1.0) };
        let f_fixed = |_t: f64, y: &Tensor| -> Tensor { tensor_scale(y, -1.0) };

        let (ts_adaptive, _) = ode_solve_rk45(f_adaptive, &y0, (0.0, 1.0), 1e-6, 1e-8);
        let (ts_fixed, _) = ode_solve_rk4(f_fixed, &y0, (0.0, 1.0), 1000);

        assert!(
            ts_adaptive.len() < ts_fixed.len(),
            "RK45 adaptive ({} steps) should take fewer steps than RK4 fixed ({} steps)",
            ts_adaptive.len() - 1,
            ts_fixed.len() - 1
        );
    }

    #[test]
    fn test_rk4_determinism() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0, 0.5], &[2]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            let d = y.to_vec();
            Tensor::from_vec_unchecked(vec![-0.5 * d[0], -0.3 * d[1]], &[2])
        };

        // The same closure is borrowed for both runs, so the dynamics are identical.
        let (ts1, ys1) = ode_solve_rk4(&f, &y0, (0.0, 1.0), 50);
        let (ts2, ys2) = ode_solve_rk4(&f, &y0, (0.0, 1.0), 50);

        assert_eq!(ts1, ts2, "RK4 time points must be bit-identical");
        for (y1, y2) in ys1.iter().zip(ys2.iter()) {
            assert_eq!(y1.to_vec(), y2.to_vec(), "RK4 solutions must be bit-identical");
        }
    }

    #[test]
    fn test_rk45_determinism() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);

        let run = || ode_solve_rk45(
            |_t, y| tensor_scale(y, -1.0),
            &y0,
            (0.0, 2.0),
            1e-6,
            1e-9,
        );

        let (ts1, ys1) = run();
        let (ts2, ys2) = run();
        assert_eq!(ts1, ts2, "RK45 time points must be bit-identical");
        for (y1, y2) in ys1.iter().zip(ys2.iter()) {
            assert_eq!(y1.to_vec(), y2.to_vec(), "RK45 solutions must be bit-identical");
        }
    }

    #[test]
    fn test_adjoint_linear_ode() {
        // For dy/dt = -y with y(0) = 1, the final state is y(1) = e^{-1};
        // integrating backwards should recover y(0) = 1.
        let t0 = 0.0_f64;
        let t1 = 1.0_f64;
        let y_final = Tensor::from_vec_unchecked(vec![(-t1).exp()], &[1]);

        let (y0_rec, _adj) = adjoint_solve(
            |_t, y| tensor_scale(y, -1.0),
            |_t, _y, a| {
                let adj_y = tensor_scale(a, 1.0);
                let adj_theta = Tensor::from_vec_unchecked(vec![0.0], &[1]);
                (adj_y, adj_theta)
            },
            &y_final,
            (t0, t1),
            1000,
        );

        let y0_val = y0_rec.to_vec()[0];
        assert!(
            (y0_val - 1.0).abs() < 1e-6,
            "adjoint_solve should recover y(0)=1.0, got {}",
            y0_val
        );
    }

    #[test]
    fn test_adjoint_gradient_vs_finite_diff() {
        // For dy/dt = alpha * y with y(0) = 1 and loss L = y(t1), the adjoint
        // gradient dL/dalpha should match a central finite difference.
        let t1 = 0.5_f64;
        let alpha = 1.0_f64;
        let y_final_val = (alpha * t1).exp();
        let y_final = Tensor::from_vec_unchecked(vec![y_final_val], &[1]);

        let eps = 1e-5;
        let l_plus = ((alpha + eps) * t1).exp();
        let l_minus = ((alpha - eps) * t1).exp();
        let fd_grad = (l_plus - l_minus) / (2.0 * eps);

        // Terminal adjoint dL/dy(t1) = 1 for L = y(t1).
        let a_terminal = Tensor::from_vec_unchecked(vec![1.0_f64], &[1]);

        let n_steps = 500;
        let h = t1 / n_steps as f64;
        let mut t = t1;
        let mut y = y_final.clone();
        let mut a = a_terminal;

        let mut grad_alpha_acc = 0.0_f64;

        for _ in 0..n_steps {
            let t_prev = t - h;
            // Backward-time rates: state rate -alpha*y, adjoint rate alpha*a.
            let ky1 = tensor_scale(&tensor_scale(&y, alpha), -1.0);
            let ka1 = tensor_scale(&a, alpha);

            let y2 = tensor_add_scaled(&y, &ky1, h * 0.5);
            let a2 = tensor_add_scaled(&a, &ka1, h * 0.5);
            let ky2 = tensor_scale(&tensor_scale(&y2, alpha), -1.0);
            let ka2 = tensor_scale(&a2, alpha);

            let y3 = tensor_add_scaled(&y, &ky2, h * 0.5);
            let a3 = tensor_add_scaled(&a, &ka2, h * 0.5);
            let ky3 = tensor_scale(&tensor_scale(&y3, alpha), -1.0);
            let ka3 = tensor_scale(&a3, alpha);

            let y4 = tensor_add_scaled(&y, &ky3, h);
            let a4 = tensor_add_scaled(&a, &ka3, h);
            let ky4 = tensor_scale(&tensor_scale(&y4, alpha), -1.0);
            let ka4 = tensor_scale(&a4, alpha);

            // Accumulate dL/dalpha = integral of a(t) * df/dalpha = a(t) * y(t).
            let ay = a.to_vec()[0] * y.to_vec()[0];
            grad_alpha_acc += h * ay;

            y = ode_step_rk4(&y, &ky1, &ky2, &ky3, &ky4, h);
            a = ode_step_rk4(&a, &ka1, &ka2, &ka3, &ka4, h);
            t = t_prev;
        }

        assert!(
            (grad_alpha_acc - fd_grad).abs() / fd_grad.abs() < 1e-4,
            "adjoint gradient {} should match finite diff {} (rel err = {})",
            grad_alpha_acc, fd_grad,
            (grad_alpha_acc - fd_grad).abs() / fd_grad.abs()
        );
    }

    #[test]
    fn test_adjoint_determinism() {
        let y_final = Tensor::from_vec_unchecked(vec![(-1.0_f64).exp()], &[1]);

        let run = || adjoint_solve(
            |_t, y| tensor_scale(y, -1.0),
            |_t, _y, a| (tensor_scale(a, 1.0), Tensor::from_vec_unchecked(vec![0.0], &[1])),
            &y_final,
            (0.0, 1.0),
            100,
        );

        let (y1, a1) = run();
        let (y2, a2) = run();
        assert_eq!(y1.to_vec(), y2.to_vec(), "adjoint_solve y0 must be bit-identical");
        assert_eq!(a1.to_vec(), a2.to_vec(), "adjoint_solve adjoint must be bit-identical");
    }
}