cjc_runtime/ode.rs

//! ODE / PDE Solver Infrastructure — Minimal Primitives
//!
//! These are stub primitives designed for future library integration (Bastion).
//! They provide the foundational stepping functions; full solver loops and
//! adaptive algorithms will be built as CJC library code on top of these.

use crate::tensor::Tensor;
use cjc_repro::kahan_sum_f64;

// ---------------------------------------------------------------------------
// ODE Stepping Primitives
// ---------------------------------------------------------------------------

/// Single step of Euler's method: y_{n+1} = y_n + h * f(t_n, y_n).
///
/// # Arguments
/// * `y` - Current state vector (1D tensor)
/// * `dydt` - Derivative vector at current time (1D tensor, same shape as y)
/// * `h` - Step size
///
/// # Returns
/// New state vector y_{n+1}
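///
/// # Example
///
/// A minimal sketch (marked `ignore` because it assumes the crate-internal
/// `Tensor` API used throughout this module):
///
/// ```ignore
/// let y = Tensor::from_vec_unchecked(vec![1.0], &[1]);
/// let dydt = Tensor::from_vec_unchecked(vec![2.0], &[1]);
/// // y_{n+1} = 1.0 + 0.1 * 2.0 = 1.2
/// assert!((ode_step_euler(&y, &dydt, 0.1).to_vec()[0] - 1.2).abs() < 1e-15);
/// ```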
pub fn ode_step_euler(y: &Tensor, dydt: &Tensor, h: f64) -> Tensor {
    let y_data = y.to_vec();
    let dy_data = dydt.to_vec();
    assert_eq!(y_data.len(), dy_data.len(), "ode_step_euler: y and dydt must have same length");

    let result: Vec<f64> = y_data.iter().zip(dy_data.iter())
        .map(|(&yi, &dyi)| yi + h * dyi)
        .collect();
    Tensor::from_vec_unchecked(result, y.shape())
}

/// Single step of classical RK4: 4th-order Runge-Kutta.
///
/// Takes four derivative evaluations (k1, k2, k3, k4) and combines them
/// using the standard RK4 formula:
///   y_{n+1} = y_n + (h/6)(k1 + 2*k2 + 2*k3 + k4)
///
/// The caller is responsible for evaluating k1..k4 at the appropriate
/// intermediate points. This keeps the stepping primitive pure and
/// independent of the RHS function.
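///
/// # Example
///
/// A single-step sketch for y' = -y (marked `ignore`; it assumes the
/// crate-internal `Tensor` API and the module-private `tensor_scale` /
/// `tensor_add_scaled` helpers defined later in this file):
///
/// ```ignore
/// let f = |_t: f64, y: &Tensor| tensor_scale(y, -1.0);
/// let y = Tensor::from_vec_unchecked(vec![1.0], &[1]);
/// let (t, h) = (0.0, 0.1);
/// let k1 = f(t, &y);
/// let k2 = f(t + h * 0.5, &tensor_add_scaled(&y, &k1, h * 0.5));
/// let k3 = f(t + h * 0.5, &tensor_add_scaled(&y, &k2, h * 0.5));
/// let k4 = f(t + h, &tensor_add_scaled(&y, &k3, h));
/// let y_next = ode_step_rk4(&y, &k1, &k2, &k3, &k4, h);
/// assert!((y_next.to_vec()[0] - (-0.1_f64).exp()).abs() < 1e-7);
/// ```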
pub fn ode_step_rk4(y: &Tensor, k1: &Tensor, k2: &Tensor, k3: &Tensor, k4: &Tensor, h: f64) -> Tensor {
    let y_data = y.to_vec();
    let k1_data = k1.to_vec();
    let k2_data = k2.to_vec();
    let k3_data = k3.to_vec();
    let k4_data = k4.to_vec();
    let n = y_data.len();
    assert_eq!(k1_data.len(), n);
    assert_eq!(k2_data.len(), n);
    assert_eq!(k3_data.len(), n);
    assert_eq!(k4_data.len(), n);

    let h6 = h / 6.0;
    let result: Vec<f64> = (0..n)
        .map(|i| {
            // Use Kahan summation for the weighted sum to maintain determinism
            let terms = [
                k1_data[i],
                2.0 * k2_data[i],
                2.0 * k3_data[i],
                k4_data[i],
            ];
            y_data[i] + h6 * kahan_sum_f64(&terms)
        })
        .collect();
    Tensor::from_vec_unchecked(result, y.shape())
}

// ---------------------------------------------------------------------------
// PDE Stepping Primitives
// ---------------------------------------------------------------------------

/// 1D finite-difference Laplacian: d^2u/dx^2 ≈ (u[i-1] - 2*u[i] + u[i+1]) / dx^2
///
/// Boundary condition: Dirichlet (u[0] and u[n-1] are held fixed, not updated).
///
/// # Arguments
/// * `u` - Current field values (1D tensor of length n)
/// * `dx` - Grid spacing
///
/// # Returns
/// Laplacian approximation (1D tensor, same shape; boundary elements are 0.0)
pub fn pde_laplacian_1d(u: &Tensor, dx: f64) -> Tensor {
    let data = u.to_vec();
    let n = data.len();
    let dx2_inv = 1.0 / (dx * dx);
    let mut lap = vec![0.0_f64; n];

    // Guard: fields with fewer than 3 points have no interior (and `1..n - 1`
    // would underflow for n = 0), so leave the result all zeros.
    if n >= 3 {
        for i in 1..n - 1 {
            lap[i] = (data[i - 1] - 2.0 * data[i] + data[i + 1]) * dx2_inv;
        }
    }

    Tensor::from_vec_unchecked(lap, u.shape())
}

/// Single explicit Euler step for a heat/diffusion PDE:
///   u_{n+1} = u_n + dt * alpha * laplacian(u_n)
///
/// # Arguments
/// * `u` - Current field values
/// * `alpha` - Diffusion coefficient
/// * `dt` - Time step
/// * `dx` - Spatial grid spacing
///
/// # Returns
/// Updated field values
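///
/// # Stability
///
/// The explicit scheme is conditionally stable: it requires
/// `alpha * dt / dx^2 <= 0.5` (the standard von Neumann stability limit for
/// explicit 1D diffusion). A stepping-loop sketch under that constraint
/// (marked `ignore`; assumes the crate-internal `Tensor` API):
///
/// ```ignore
/// let mut u = Tensor::from_vec_unchecked(vec![0.0, 0.0, 1.0, 0.0, 0.0], &[5]);
/// let (alpha, dx) = (1.0, 0.1);
/// let dt = 0.4 * dx * dx / alpha; // alpha * dt / dx^2 = 0.4 <= 0.5
/// for _ in 0..100 {
///     u = pde_step_diffusion(&u, alpha, dt, dx);
/// }
/// ```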
pub fn pde_step_diffusion(u: &Tensor, alpha: f64, dt: f64, dx: f64) -> Tensor {
    let lap = pde_laplacian_1d(u, dx);
    let u_data = u.to_vec();
    let lap_data = lap.to_vec();
    let result: Vec<f64> = u_data.iter().zip(lap_data.iter())
        .map(|(&ui, &li)| ui + dt * alpha * li)
        .collect();
    Tensor::from_vec_unchecked(result, u.shape())
}

// ---------------------------------------------------------------------------
// Symbolic Differentiation Primitives
// ---------------------------------------------------------------------------

/// Symbolic expression representation for automatic symbolic differentiation.
///
/// These are value-level symbolic expressions that can be differentiated
/// symbolically before evaluation. This provides exact derivatives without
/// numerical error.
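///
/// # Example
///
/// A small sketch (marked `ignore`; the tests below exercise the same path):
///
/// ```ignore
/// use std::collections::BTreeMap;
/// // f(x) = x^2, so f'(x) = 2x.
/// let f = SymExpr::Pow(Box::new(SymExpr::Var("x".into())), 2.0);
/// let df = f.differentiate("x");
/// let mut env = BTreeMap::new();
/// env.insert("x".to_string(), 3.0);
/// assert!((df.eval(&env) - 6.0).abs() < 1e-12);
/// ```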
#[derive(Debug, Clone, PartialEq)]
pub enum SymExpr {
    /// Constant value
    Const(f64),
    /// Variable reference (by name)
    Var(String),
    /// Addition
    Add(Box<SymExpr>, Box<SymExpr>),
    /// Multiplication
    Mul(Box<SymExpr>, Box<SymExpr>),
    /// Power: base^exponent (exponent is a constant)
    Pow(Box<SymExpr>, f64),
    /// Sine
    Sin(Box<SymExpr>),
    /// Cosine
    Cos(Box<SymExpr>),
    /// Exponential
    Exp(Box<SymExpr>),
    /// Natural logarithm
    Ln(Box<SymExpr>),
    /// Negation
    Neg(Box<SymExpr>),
}

impl SymExpr {
    /// Symbolically differentiate with respect to variable `var`.
    pub fn differentiate(&self, var: &str) -> SymExpr {
        match self {
            SymExpr::Const(_) => SymExpr::Const(0.0),
            SymExpr::Var(name) => {
                if name == var {
                    SymExpr::Const(1.0)
                } else {
                    SymExpr::Const(0.0)
                }
            }
            SymExpr::Add(a, b) => SymExpr::Add(
                Box::new(a.differentiate(var)),
                Box::new(b.differentiate(var)),
            ),
            SymExpr::Mul(a, b) => {
                // Product rule: (f*g)' = f'*g + f*g'
                SymExpr::Add(
                    Box::new(SymExpr::Mul(
                        Box::new(a.differentiate(var)),
                        b.clone(),
                    )),
                    Box::new(SymExpr::Mul(
                        a.clone(),
                        Box::new(b.differentiate(var)),
                    )),
                )
            }
            SymExpr::Pow(base, exp) => {
                // d/dx [f^n] = n * f^(n-1) * f'
                SymExpr::Mul(
                    Box::new(SymExpr::Mul(
                        Box::new(SymExpr::Const(*exp)),
                        Box::new(SymExpr::Pow(base.clone(), exp - 1.0)),
                    )),
                    Box::new(base.differentiate(var)),
                )
            }
            SymExpr::Sin(inner) => {
                // d/dx sin(f) = cos(f) * f'
                SymExpr::Mul(
                    Box::new(SymExpr::Cos(inner.clone())),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Cos(inner) => {
                // d/dx cos(f) = -sin(f) * f'
                SymExpr::Mul(
                    Box::new(SymExpr::Neg(Box::new(SymExpr::Sin(inner.clone())))),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Exp(inner) => {
                // d/dx exp(f) = exp(f) * f'
                SymExpr::Mul(
                    Box::new(SymExpr::Exp(inner.clone())),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Ln(inner) => {
                // d/dx ln(f) = f' / f
                SymExpr::Mul(
                    Box::new(SymExpr::Pow(inner.clone(), -1.0)),
                    Box::new(inner.differentiate(var)),
                )
            }
            SymExpr::Neg(inner) => {
                SymExpr::Neg(Box::new(inner.differentiate(var)))
            }
        }
    }

    /// Evaluate the symbolic expression with the given variable bindings.
    /// Unbound variables evaluate to 0.0.
    pub fn eval(&self, bindings: &std::collections::BTreeMap<String, f64>) -> f64 {
        match self {
            SymExpr::Const(c) => *c,
            // Missing bindings default to 0.0 rather than panicking.
            SymExpr::Var(name) => *bindings.get(name).unwrap_or(&0.0),
            SymExpr::Add(a, b) => a.eval(bindings) + b.eval(bindings),
            SymExpr::Mul(a, b) => a.eval(bindings) * b.eval(bindings),
            SymExpr::Pow(base, exp) => base.eval(bindings).powf(*exp),
            SymExpr::Sin(inner) => inner.eval(bindings).sin(),
            SymExpr::Cos(inner) => inner.eval(bindings).cos(),
            SymExpr::Exp(inner) => inner.eval(bindings).exp(),
            SymExpr::Ln(inner) => inner.eval(bindings).ln(),
            SymExpr::Neg(inner) => -inner.eval(bindings),
        }
    }
}

// ---------------------------------------------------------------------------
// Full ODE Solver Loops (Sprint 3)
// ---------------------------------------------------------------------------

/// Add two tensors element-wise: result[i] = a[i] + b[i].
#[allow(dead_code)]
fn tensor_add(a: &Tensor, b: &Tensor) -> Tensor {
    let a_data = a.to_vec();
    let b_data = b.to_vec();
    debug_assert_eq!(a_data.len(), b_data.len());
    let result: Vec<f64> = a_data.iter().zip(b_data.iter()).map(|(&ai, &bi)| ai + bi).collect();
    Tensor::from_vec_unchecked(result, a.shape())
}

/// Scale a tensor element-wise: result[i] = scalar * a[i].
fn tensor_scale(a: &Tensor, scalar: f64) -> Tensor {
    let a_data = a.to_vec();
    let result: Vec<f64> = a_data.iter().map(|&ai| scalar * ai).collect();
    Tensor::from_vec_unchecked(result, a.shape())
}

/// Weighted sum of tensors: result[i] = a[i] + scalar * b[i].
fn tensor_add_scaled(a: &Tensor, b: &Tensor, scalar: f64) -> Tensor {
    let a_data = a.to_vec();
    let b_data = b.to_vec();
    debug_assert_eq!(a_data.len(), b_data.len());
    let result: Vec<f64> = a_data.iter().zip(b_data.iter()).map(|(&ai, &bi)| ai + scalar * bi).collect();
    Tensor::from_vec_unchecked(result, a.shape())
}

/// Compute the L2 norm of a tensor using Kahan summation for determinism.
fn tensor_norm(a: &Tensor) -> f64 {
    let data = a.to_vec();
    let terms: Vec<f64> = data.iter().map(|&x| x * x).collect();
    kahan_sum_f64(&terms).sqrt()
}

/// Full RK4 solver: integrates dy/dt = f(t, y) over [t0, t1] using `n_steps` equal steps.
///
/// Uses the classical 4th-order Runge-Kutta method with Kahan summation for
/// the weighted combination step to preserve bit-identical results.
///
/// # Arguments
/// * `f` - RHS function: f(t, y) → dy/dt
/// * `y0` - Initial state (1D tensor)
/// * `t_span` - (t0, t1) integration interval
/// * `n_steps` - Number of uniform steps
///
/// # Returns
/// `(time_points, solution_tensors)` — vectors of length `n_steps + 1`
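///
/// # Example
///
/// A sketch (marked `ignore`; `test_rk4_exponential_decay` below is the
/// compiled version):
///
/// ```ignore
/// // y' = -y, y(0) = 1  →  y(1) ≈ exp(-1)
/// let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
/// let (ts, ys) = ode_solve_rk4(|_t, y| tensor_scale(y, -1.0), &y0, (0.0, 1.0), 100);
/// assert_eq!(ts.len(), 101);
/// assert!((ys[100].to_vec()[0] - (-1.0_f64).exp()).abs() < 1e-8);
/// ```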
pub fn ode_solve_rk4<F>(
    mut f: F,
    y0: &Tensor,
    t_span: (f64, f64),
    n_steps: usize,
) -> (Vec<f64>, Vec<Tensor>)
where
    F: FnMut(f64, &Tensor) -> Tensor,
{
    assert!(n_steps > 0, "ode_solve_rk4: n_steps must be > 0");
    let (t0, t1) = t_span;
    let h = (t1 - t0) / n_steps as f64;

    let mut ts = Vec::with_capacity(n_steps + 1);
    let mut ys = Vec::with_capacity(n_steps + 1);

    ts.push(t0);
    ys.push(y0.clone());

    let mut t = t0;
    let mut y = y0.clone();

    for _ in 0..n_steps {
        let k1 = f(t, &y);
        let y2 = tensor_add_scaled(&y, &k1, h * 0.5);
        let k2 = f(t + h * 0.5, &y2);
        let y3 = tensor_add_scaled(&y, &k2, h * 0.5);
        let k3 = f(t + h * 0.5, &y3);
        let y4 = tensor_add_scaled(&y, &k3, h);
        let k4 = f(t + h, &y4);

        // Use the ode_step_rk4 primitive, which applies Kahan summation internally
        y = ode_step_rk4(&y, &k1, &k2, &k3, &k4, h);
        t += h;

        ts.push(t);
        ys.push(y.clone());
    }

    (ts, ys)
}

/// Dormand-Prince RK45 Butcher tableau coefficients.
/// These are the standard DP5 coefficients.
mod dp5 {
    pub const C2: f64 = 1.0 / 5.0;
    pub const C3: f64 = 3.0 / 10.0;
    pub const C4: f64 = 4.0 / 5.0;
    pub const C5: f64 = 8.0 / 9.0;
    // C6 = 1.0, C7 = 1.0

    pub const A21: f64 = 1.0 / 5.0;
    pub const A31: f64 = 3.0 / 40.0;
    pub const A32: f64 = 9.0 / 40.0;
    pub const A41: f64 = 44.0 / 45.0;
    pub const A42: f64 = -56.0 / 15.0;
    pub const A43: f64 = 32.0 / 9.0;
    pub const A51: f64 = 19372.0 / 6561.0;
    pub const A52: f64 = -25360.0 / 2187.0;
    pub const A53: f64 = 64448.0 / 6561.0;
    pub const A54: f64 = -212.0 / 729.0;
    pub const A61: f64 = 9017.0 / 3168.0;
    pub const A62: f64 = -355.0 / 33.0;
    pub const A63: f64 = 46732.0 / 5247.0;
    pub const A64: f64 = 49.0 / 176.0;
    pub const A65: f64 = -5103.0 / 18656.0;

    // 5th-order solution weights (b)
    pub const B1: f64 = 35.0 / 384.0;
    // B2 = 0
    pub const B3: f64 = 500.0 / 1113.0;
    pub const B4: f64 = 125.0 / 192.0;
    pub const B5: f64 = -2187.0 / 6784.0;
    pub const B6: f64 = 11.0 / 84.0;
    // B7 = 0

    // 4th-order error estimate weights (e = b - b*)
    // b* are the 4th-order weights; e_i = b_i - b*_i
    pub const E1: f64 = 71.0 / 57600.0;
    // E2 = 0
    pub const E3: f64 = -71.0 / 16695.0;
    pub const E4: f64 = 71.0 / 1920.0;
    pub const E5: f64 = -17253.0 / 339200.0;
    pub const E6: f64 = 22.0 / 525.0;
    pub const E7: f64 = -1.0 / 40.0;
}

/// Adaptive Dormand-Prince RK45 solver.
///
/// Integrates dy/dt = f(t, y) over t_span using adaptive step control.
/// Error is estimated as the difference between the 5th- and 4th-order solutions.
/// Step size is adjusted to keep the local error within `atol + rtol * |y|`.
///
/// # Arguments
/// * `f` - RHS function: f(t, y) → dy/dt
/// * `y0` - Initial state (1D tensor)
/// * `t_span` - (t0, t1) integration interval
/// * `rtol` - Relative tolerance (e.g. 1e-6)
/// * `atol` - Absolute tolerance (e.g. 1e-9)
///
/// # Returns
/// `(time_points, solution_tensors)` — variable length: the initial point plus
/// one entry per accepted step.
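///
/// # Example
///
/// A sketch (marked `ignore`; `test_rk45_exponential_decay` below is the
/// compiled version):
///
/// ```ignore
/// let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
/// let (ts, ys) = ode_solve_rk45(|_t, y| tensor_scale(y, -1.0), &y0, (0.0, 1.0), 1e-8, 1e-10);
/// let y_end = ys.last().unwrap().to_vec()[0];
/// assert!((y_end - (-1.0_f64).exp()).abs() < 1e-6);
/// ```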
pub fn ode_solve_rk45<F>(
    mut f: F,
    y0: &Tensor,
    t_span: (f64, f64),
    rtol: f64,
    atol: f64,
) -> (Vec<f64>, Vec<Tensor>)
where
    F: FnMut(f64, &Tensor) -> Tensor,
{
    let (t0, t1) = t_span;
    assert!(t1 > t0, "ode_solve_rk45: t1 must be > t0");

    let mut ts = Vec::new();
    let mut ys = Vec::new();

    ts.push(t0);
    ys.push(y0.clone());

    let n = y0.to_vec().len();

    // Initial step size heuristic
    let f0 = f(t0, y0);
    let f0_norm = tensor_norm(&f0).max(1e-300);
    let mut h = (0.01 * (t1 - t0)).min(0.1 / f0_norm);
    h = h.max(1e-12);

    let mut t = t0;
    let mut y = y0.clone();
    let safety = 0.9_f64;
    let max_factor = 10.0_f64;
    let min_factor = 0.2_f64;
    let max_steps = 1_000_000_usize;
    let mut step_count = 0;

    while t < t1 && step_count < max_steps {
        // Don't overshoot the endpoint
        if t + h > t1 {
            h = t1 - t;
        }

        // Evaluate all 7 stages of Dormand-Prince
        let k1 = f(t, &y);
        let y2 = tensor_add_scaled(&y, &k1, h * dp5::A21);
        let k2 = f(t + dp5::C2 * h, &y2);
        // k3 stage
        let mut y3_data = y.to_vec();
        let k1d = k1.to_vec();
        let k2d = k2.to_vec();
        for i in 0..n {
            y3_data[i] += h * (dp5::A31 * k1d[i] + dp5::A32 * k2d[i]);
        }
        let y3 = Tensor::from_vec_unchecked(y3_data, y.shape());
        let k3 = f(t + dp5::C3 * h, &y3);
        // k4 stage
        let k3d = k3.to_vec();
        let mut y4_data = y.to_vec();
        for i in 0..n {
            y4_data[i] += h * (dp5::A41 * k1d[i] + dp5::A42 * k2d[i] + dp5::A43 * k3d[i]);
        }
        let y4 = Tensor::from_vec_unchecked(y4_data, y.shape());
        let k4 = f(t + dp5::C4 * h, &y4);
        // k5 stage
        let k4d = k4.to_vec();
        let mut y5_data = y.to_vec();
        for i in 0..n {
            y5_data[i] += h * (dp5::A51 * k1d[i] + dp5::A52 * k2d[i] + dp5::A53 * k3d[i] + dp5::A54 * k4d[i]);
        }
        let y5 = Tensor::from_vec_unchecked(y5_data, y.shape());
        let k5 = f(t + dp5::C5 * h, &y5);
        // k6 stage
        let k5d = k5.to_vec();
        let mut y6_data = y.to_vec();
        for i in 0..n {
            y6_data[i] += h * (dp5::A61 * k1d[i] + dp5::A62 * k2d[i] + dp5::A63 * k3d[i] + dp5::A64 * k4d[i] + dp5::A65 * k5d[i]);
        }
        let y6 = Tensor::from_vec_unchecked(y6_data, y.shape());
        let k6 = f(t + h, &y6);
        // 5th-order solution
        let k6d = k6.to_vec();
        let y_data = y.to_vec();
        let mut y5th_data = vec![0.0_f64; n];
        for i in 0..n {
            let terms = [
                dp5::B1 * k1d[i],
                dp5::B3 * k3d[i],
                dp5::B4 * k4d[i],
                dp5::B5 * k5d[i],
                dp5::B6 * k6d[i],
            ];
            y5th_data[i] = y_data[i] + h * kahan_sum_f64(&terms);
        }
        let y5th = Tensor::from_vec_unchecked(y5th_data.clone(), y.shape());

        // k7 stage (the FSAL stage — "first same as last"; this implementation
        // re-evaluates k1 each step instead of reusing k7, for simplicity)
        let k7 = f(t + h, &y5th);
        let k7d = k7.to_vec();

        // Error estimate: (5th-order) - (4th-order) solution
        //   = h * (E1*k1 + E3*k3 + E4*k4 + E5*k5 + E6*k6 + E7*k7)
        let mut err_sq_acc = 0.0_f64;
        for i in 0..n {
            let e_terms = [
                dp5::E1 * k1d[i],
                dp5::E3 * k3d[i],
                dp5::E4 * k4d[i],
                dp5::E5 * k5d[i],
                dp5::E6 * k6d[i],
                dp5::E7 * k7d[i],
            ];
            let e_i = h * kahan_sum_f64(&e_terms);
            let sc = atol + rtol * y5th_data[i].abs().max(y_data[i].abs());
            err_sq_acc += (e_i / sc) * (e_i / sc);
        }
        let err_norm = (err_sq_acc / n as f64).sqrt();

        if err_norm <= 1.0 {
            // Accept step
            t += h;
            y = y5th;
            ts.push(t);
            ys.push(y.clone());
            step_count += 1;

            // Compute new step size
            // Standard 5th-order controller: clamp(safety * err^(-1/5)).
            // Note the parentheses: the clamp must apply to the whole product,
            // not just to the powf result.
            let factor = (safety * err_norm.powf(-0.2)).min(max_factor).max(min_factor);
            h = (h * factor).min(t1 - t);
            if h < 1e-14 {
                break;
            }
        } else {
            // Reject step — reduce h
            let factor = (safety * err_norm.powf(-0.25)).max(min_factor);
            h *= factor;
            if h < 1e-14 {
                break;
            }
        }
    }

    (ts, ys)
}

/// Adjoint method for Neural ODEs — O(1) memory gradient computation.
///
/// Given the final state y(T) and a loss gradient (adjoint at T), integrates
/// the adjoint ODE backward in time to recover y(t0) and the adjoint a(t0).
///
/// The augmented backward system is:
///   dy/dt  = f(t, y)               (forward ODE — integrated backward)
///   da/dt  = -a^T * (df/dy)        (adjoint ODE)
///
/// Here `grad_f` provides both:
///   - The Jacobian-vector product a^T * J_y f, i.e. (df/dy)^T * a
///   - The gradient w.r.t. parameters: a^T * (df/dtheta)
///     (this second component is currently ignored by this primitive; see
///     `test_adjoint_gradient_vs_finite_diff` for how it is accumulated)
///
/// This implementation uses RK4 backward integration for reproducibility.
///
/// # Arguments
/// * `f` - Forward dynamics: f(t, y) → dy/dt
/// * `grad_f` - Returns (vjp_y, vjp_theta): Jacobian-vector products with the adjoint.
///   Signature: grad_f(t, y, adjoint) → ((df/dy)^T * adjoint, (df/dtheta)^T * adjoint)
/// * `y_final` - State at final time T
/// * `t_span` - (t0, T) — integrates BACKWARD from T to t0
/// * `n_steps` - Number of backward integration steps
///
/// # Returns
/// `(y0_reconstructed, adjoint_at_t0)`
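///
/// # Example
///
/// A sketch for the linear ODE y' = -y (marked `ignore`;
/// `test_adjoint_linear_ode` below is the compiled version):
///
/// ```ignore
/// // y(1) = exp(-1); reconstruct y(0) ≈ 1.0 by integrating backward.
/// let y_t = Tensor::from_vec_unchecked(vec![(-1.0_f64).exp()], &[1]);
/// let (y0, _a0) = adjoint_solve(
///     |_t, y| tensor_scale(y, -1.0),
///     // df/dy = -1, so (df/dy)^T * a = -a; no parameters here.
///     |_t, _y, a| (tensor_scale(a, -1.0), Tensor::from_vec_unchecked(vec![0.0], &[1])),
///     &y_t,
///     (0.0, 1.0),
///     1000,
/// );
/// assert!((y0.to_vec()[0] - 1.0).abs() < 1e-6);
/// ```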
pub fn adjoint_solve<F, G>(
    mut f: F,
    mut grad_f: G,
    y_final: &Tensor,
    t_span: (f64, f64),
    n_steps: usize,
) -> (Tensor, Tensor)
where
    F: FnMut(f64, &Tensor) -> Tensor,
    G: FnMut(f64, &Tensor, &Tensor) -> (Tensor, Tensor),
{
    assert!(n_steps > 0, "adjoint_solve: n_steps must be > 0");
    let (t0, t1) = t_span;
    // h is the (positive) backward step size; we integrate from t1 down to t0.
    let h = (t1 - t0) / n_steps as f64;

    let n = y_final.to_vec().len();

    // Adjoint at T: initialized to zero. In the Neural ODE formulation the
    // terminal adjoint is typically dL/dy(T); since the adjoint ODE is linear
    // in `a`, a zero terminal adjoint stays zero, so callers that need dL/dy(t0)
    // must seed the terminal condition themselves (see the gradient test below).
    let a0 = Tensor::from_vec_unchecked(vec![0.0_f64; n], y_final.shape());

    let mut t = t1;
    let mut y = y_final.clone();
    let mut a = a0;

    for _ in 0..n_steps {
        let t_prev = t - h;

        // ---- RK4 backward step for [y, a] ----
        // Augmented state: z = [y; a]. In the backward pseudo-time s = t1 - t:
        //   dy/ds = -f(t, y)
        //   da/ds = +(df/dy)^T * a
        // We negate f and march with the positive step h, so RK4 runs backward.

        // k1 for y: -f(t, y)
        let ky1 = tensor_scale(&f(t, &y), -1.0);
        // k1 for a: +(df/dy)^T * a
        let (ka1, _) = grad_f(t, &y, &a);

        // k2
        let y2 = tensor_add_scaled(&y, &ky1, h * 0.5);
        let a2 = tensor_add_scaled(&a, &ka1, h * 0.5);
        let ky2 = tensor_scale(&f(t - h * 0.5, &y2), -1.0);
        let (ka2, _) = grad_f(t - h * 0.5, &y2, &a2);

        // k3
        let y3 = tensor_add_scaled(&y, &ky2, h * 0.5);
        let a3 = tensor_add_scaled(&a, &ka2, h * 0.5);
        let ky3 = tensor_scale(&f(t - h * 0.5, &y3), -1.0);
        let (ka3, _) = grad_f(t - h * 0.5, &y3, &a3);

        // k4
        let y4 = tensor_add_scaled(&y, &ky3, h);
        let a4 = tensor_add_scaled(&a, &ka3, h);
        let ky4 = tensor_scale(&f(t_prev, &y4), -1.0);
        let (ka4, _) = grad_f(t_prev, &y4, &a4);

        // RK4 combination (uses Kahan summation via ode_step_rk4)
        y = ode_step_rk4(&y, &ky1, &ky2, &ky3, &ky4, h);
        a = ode_step_rk4(&a, &ka1, &ka2, &ka3, &ka4, h);
        t = t_prev;
    }

    (y, a)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn test_euler_step() {
        let y = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
        let dydt = Tensor::from_vec_unchecked(vec![0.0, 1.0], &[2]);
        let y1 = ode_step_euler(&y, &dydt, 0.1);
        let result = y1.to_vec();
        assert!((result[0] - 1.0).abs() < 1e-15);
        assert!((result[1] - 0.1).abs() < 1e-15);
    }

    #[test]
    fn test_rk4_step_constant() {
        let y = Tensor::from_vec_unchecked(vec![1.0], &[1]);
        let k = Tensor::from_vec_unchecked(vec![2.0], &[1]);
        let y1 = ode_step_rk4(&y, &k, &k, &k, &k, 0.1);
        // y1 = 1.0 + (0.1/6)*(2+4+4+2) = 1.0 + 0.2 = 1.2
        assert!((y1.to_vec()[0] - 1.2).abs() < 1e-14);
    }

    #[test]
    fn test_laplacian_1d() {
        // u = [0, 1, 4, 9, 16] (x^2 at x=0..4, dx=1)
        // d^2u/dx^2 = 2 everywhere in the interior
        let u = Tensor::from_vec_unchecked(vec![0.0, 1.0, 4.0, 9.0, 16.0], &[5]);
        let lap = pde_laplacian_1d(&u, 1.0);
        let data = lap.to_vec();
        assert!((data[0] - 0.0).abs() < 1e-14); // boundary
        assert!((data[1] - 2.0).abs() < 1e-14);
        assert!((data[2] - 2.0).abs() < 1e-14);
        assert!((data[3] - 2.0).abs() < 1e-14);
        assert!((data[4] - 0.0).abs() < 1e-14); // boundary
    }

    #[test]
    fn test_symbolic_diff_polynomial() {
        // f(x) = x^3, f'(x) = 3*x^2
        let expr = SymExpr::Pow(Box::new(SymExpr::Var("x".into())), 3.0);
        let deriv = expr.differentiate("x");

        let mut bindings = BTreeMap::new();
        bindings.insert("x".into(), 2.0);

        let val = deriv.eval(&bindings);
        assert!((val - 12.0).abs() < 1e-12); // 3 * 2^2 = 12
    }

    #[test]
    fn test_symbolic_diff_sin() {
        // f(x) = sin(x), f'(x) = cos(x)
        let expr = SymExpr::Sin(Box::new(SymExpr::Var("x".into())));
        let deriv = expr.differentiate("x");

        let mut bindings = BTreeMap::new();
        bindings.insert("x".into(), 0.0);

        let val = deriv.eval(&bindings);
        assert!((val - 1.0).abs() < 1e-12); // cos(0) = 1
    }

    // --- Sprint 3: ODE solver tests ---

    #[test]
    fn test_rk4_exponential_decay() {
        // y' = -y, y(0) = 1  →  y(t) = exp(-t)
        // At t = 1.0 the exact answer is exp(-1) ≈ 0.36787944117
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            tensor_scale(y, -1.0)
        };
        let (ts, ys) = ode_solve_rk4(f, &y0, (0.0, 1.0), 100);

        assert_eq!(ts.len(), 101);
        assert_eq!(ys.len(), 101);
        assert!((ts[0] - 0.0).abs() < 1e-15);
        assert!((ts[100] - 1.0).abs() < 1e-12);

        let y_final = ys[100].to_vec()[0];
        let exact = (-1.0_f64).exp();
        assert!(
            (y_final - exact).abs() < 1e-8,
            "RK4 decay: got {}, expected {}",
            y_final, exact
        );
    }

    #[test]
    fn test_rk4_harmonic_oscillator() {
        // y'' = -y  with  [y, v]' = [v, -y]
        // y(0) = 1, v(0) = 0  →  y(t) = cos(t), v(t) = -sin(t)
        // At t = pi/2 ≈ 1.5708: y ≈ 0, v ≈ -1
        let y0 = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            let d = y.to_vec();
            Tensor::from_vec_unchecked(vec![d[1], -d[0]], &[2])
        };
        let t_end = std::f64::consts::PI / 2.0;
        let (_ts, ys) = ode_solve_rk4(f, &y0, (0.0, t_end), 1000);

        let y_end = ys.last().unwrap().to_vec();
        // y(pi/2) = cos(pi/2) ≈ 0
        assert!(
            y_end[0].abs() < 1e-7,
            "harmonic y(pi/2) should be ~0, got {}",
            y_end[0]
        );
        // v(pi/2) = -sin(pi/2) ≈ -1
        assert!(
            (y_end[1] - (-1.0)).abs() < 1e-7,
            "harmonic v(pi/2) should be ~-1, got {}",
            y_end[1]
        );
    }

    #[test]
    fn test_rk45_exponential_decay() {
        // y' = -y, y(0) = 1 → y(1) = exp(-1)
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
        let f = |_t: f64, y: &Tensor| -> Tensor {
            tensor_scale(y, -1.0)
        };
        let (ts, ys) = ode_solve_rk45(f, &y0, (0.0, 1.0), 1e-8, 1e-10);

        assert!(!ts.is_empty(), "RK45 should produce at least one step");
        let y_final = ys.last().unwrap().to_vec()[0];
        let t_final = *ts.last().unwrap();
        let exact = (-t_final).exp();
        assert!(
            (y_final - exact).abs() < 1e-6,
            "RK45 decay: got {} at t={}, expected {}",
            y_final, t_final, exact
        );
    }

    #[test]
    fn test_rk45_fewer_steps_than_rk4_fixed() {
        // For a smooth problem (exponential decay), adaptive RK45 should take
        // fewer steps than RK4 with 1000 fixed steps.
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);

        let f_adaptive = |_t: f64, y: &Tensor| -> Tensor { tensor_scale(y, -1.0) };
        let f_fixed = |_t: f64, y: &Tensor| -> Tensor { tensor_scale(y, -1.0) };

        let (ts_adaptive, _) = ode_solve_rk45(f_adaptive, &y0, (0.0, 1.0), 1e-6, 1e-8);
        let (ts_fixed, _) = ode_solve_rk4(f_fixed, &y0, (0.0, 1.0), 1000);

        assert!(
            ts_adaptive.len() < ts_fixed.len(),
            "RK45 adaptive ({} steps) should take fewer steps than RK4 fixed ({} steps)",
            ts_adaptive.len() - 1,
            ts_fixed.len() - 1
        );
    }

    #[test]
    fn test_rk4_determinism() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0, 0.5], &[2]);
        // A non-capturing closure is `Copy`, so the same RHS can be passed to
        // both runs.
        let f = |_t: f64, y: &Tensor| -> Tensor {
            let d = y.to_vec();
            Tensor::from_vec_unchecked(vec![-0.5 * d[0], -0.3 * d[1]], &[2])
        };

        let (ts1, ys1) = ode_solve_rk4(f, &y0, (0.0, 1.0), 50);
        let (ts2, ys2) = ode_solve_rk4(f, &y0, (0.0, 1.0), 50);

        assert_eq!(ts1, ts2, "RK4 time points must be bit-identical");
        for (y1, y2) in ys1.iter().zip(ys2.iter()) {
            assert_eq!(y1.to_vec(), y2.to_vec(), "RK4 solutions must be bit-identical");
        }
    }

    #[test]
    fn test_rk45_determinism() {
        let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);

        let run = || ode_solve_rk45(
            |_t, y| tensor_scale(y, -1.0),
            &y0,
            (0.0, 2.0),
            1e-6,
            1e-9,
        );

        let (ts1, ys1) = run();
        let (ts2, ys2) = run();
        assert_eq!(ts1, ts2, "RK45 time points must be bit-identical");
        for (y1, y2) in ys1.iter().zip(ys2.iter()) {
            assert_eq!(y1.to_vec(), y2.to_vec(), "RK45 solutions must be bit-identical");
        }
    }

    #[test]
    fn test_adjoint_linear_ode() {
        // Forward: y' = -y, y(0) = 1  →  y(T) = exp(-T)
        // We verify that adjoint_solve recovers y(0) ≈ 1.0 from y(T).
        // (adjoint_solve seeds the adjoint at T with zero, so the adjoint
        // branch is inert here; only the backward reconstruction of y is
        // exercised.)

        let t0 = 0.0_f64;
        let t1 = 1.0_f64;
        let y_final = Tensor::from_vec_unchecked(vec![(-t1).exp()], &[1]);

        let (y0_rec, _adj) = adjoint_solve(
            |_t, y| tensor_scale(y, -1.0),
            |_t, _y, a| {
                // Contract: return (df/dy)^T * a. Here df/dy = -1, so this is -a.
                let adj_y = tensor_scale(a, -1.0);
                let adj_theta = Tensor::from_vec_unchecked(vec![0.0], &[1]);
                (adj_y, adj_theta)
            },
            &y_final,
            (t0, t1),
            1000,
        );

        let y0_val = y0_rec.to_vec()[0];
        assert!(
            (y0_val - 1.0).abs() < 1e-6,
            "adjoint_solve should recover y(0)=1.0, got {}",
            y0_val
        );
    }

    #[test]
    fn test_adjoint_gradient_vs_finite_diff() {
        // For y' = alpha * y, y(0) = 1:
        //   y(T) = exp(alpha * T)
        //   With loss L = y(T) = exp(alpha * T),
        //   dL/dalpha = T * exp(alpha * T)
        //
        // We test that the adjoint gives the right gradient magnitude.
        // (This is a unit test for the adjoint ODE machinery.)
        let t1 = 0.5_f64;
        let alpha = 1.0_f64;
        let y_final_val = (alpha * t1).exp();
        let y_final = Tensor::from_vec_unchecked(vec![y_final_val], &[1]);

        // Finite-difference gradient: perturb alpha
        let eps = 1e-5;
        let l_plus = ((alpha + eps) * t1).exp();
        let l_minus = ((alpha - eps) * t1).exp();
        let fd_grad = (l_plus - l_minus) / (2.0 * eps);

        // Adjoint terminal condition: a(T) = dL/dy(T) = 1.0.
        // adjoint_solve always seeds a zero adjoint, so we run the backward
        // RK4 loop manually with `a` initialized to the terminal condition.
        let a_terminal = Tensor::from_vec_unchecked(vec![1.0_f64], &[1]);

        let n_steps = 500;
        let h = t1 / n_steps as f64;
        let mut t = t1;
        let mut y = y_final.clone();
        let mut a = a_terminal;

        // Accumulate the theta gradient: dL/dalpha = integral of a(t) * y(t) dt
        let mut grad_alpha_acc = 0.0_f64;

        for _ in 0..n_steps {
            let t_prev = t - h;
            // RK4 backward step for y and a. In backward pseudo-time s = T - t:
            //   dy/ds = -alpha * y
            //   da/ds = +alpha * a   (forward-time adjoint ODE: da/dt = -alpha * a)
            let ky1 = tensor_scale(&y, -alpha);
            let ka1 = tensor_scale(&a, alpha);

            let y2 = tensor_add_scaled(&y, &ky1, h * 0.5);
            let a2 = tensor_add_scaled(&a, &ka1, h * 0.5);
            let ky2 = tensor_scale(&y2, -alpha);
            let ka2 = tensor_scale(&a2, alpha);

            let y3 = tensor_add_scaled(&y, &ky2, h * 0.5);
            let a3 = tensor_add_scaled(&a, &ka2, h * 0.5);
            let ky3 = tensor_scale(&y3, -alpha);
            let ka3 = tensor_scale(&a3, alpha);

            let y4 = tensor_add_scaled(&y, &ky3, h);
            let a4 = tensor_add_scaled(&a, &ka3, h);
            let ky4 = tensor_scale(&y4, -alpha);
            let ka4 = tensor_scale(&a4, alpha);

            // Theta gradient contribution: a(t)^T * (df/dalpha) = a(t) * y(t)
            // (since df/dalpha = y for this ODE); left-endpoint rectangle rule.
            let ay = a.to_vec()[0] * y.to_vec()[0];
            grad_alpha_acc += h * ay;

            y = ode_step_rk4(&y, &ky1, &ky2, &ky3, &ky4, h);
            a = ode_step_rk4(&a, &ka1, &ka2, &ka3, &ka4, h);
            t = t_prev;
        }

        // The adjoint gives dL/dalpha; compare to the finite difference
        assert!(
            (grad_alpha_acc - fd_grad).abs() / fd_grad.abs() < 1e-4,
            "adjoint gradient {} should match finite diff {} (rel err = {})",
            grad_alpha_acc, fd_grad,
            (grad_alpha_acc - fd_grad).abs() / fd_grad.abs()
        );
    }

    #[test]
    fn test_adjoint_determinism() {
        let y_final = Tensor::from_vec_unchecked(vec![(-1.0_f64).exp()], &[1]);

        let run = || adjoint_solve(
            |_t, y| tensor_scale(y, -1.0),
            // Contract: (df/dy)^T * a = -a for f = -y; the parameter VJP is empty here.
            |_t, _y, a| (tensor_scale(a, -1.0), Tensor::from_vec_unchecked(vec![0.0], &[1])),
            &y_final,
            (0.0, 1.0),
            100,
        );

        let (y1, a1) = run();
        let (y2, a2) = run();
        assert_eq!(y1.to_vec(), y2.to_vec(), "adjoint_solve y0 must be bit-identical");
        assert_eq!(a1.to_vec(), a2.to_vec(), "adjoint_solve adjoint must be bit-identical");
    }
}
955}