
Function adjoint_solve 

pub fn adjoint_solve<F, G>(
    f: F,
    grad_f: G,
    y_final: &Tensor,
    t_span: (f64, f64),
    n_steps: usize,
) -> (Tensor, Tensor)
where
    F: FnMut(f64, &Tensor) -> Tensor,
    G: FnMut(f64, &Tensor, &Tensor) -> (Tensor, Tensor),

Adjoint method for Neural ODEs — O(1) memory gradient computation.

Given the final state y(T) and a loss gradient (the adjoint at T), integrates the augmented system backward in time to recover y(t0) and the adjoint a(t0).

The augmented backward system is:

    dy/dt = f(t, y)           (forward ODE, integrated backward)
    da/dt = -a^T * (df/dy)    (adjoint ODE)

Here grad_f provides both:

  • The vector-Jacobian product a^T * J_y f, i.e. (df/dy)^T * a
  • The gradient w.r.t. parameters: a^T * (df/dtheta)
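As a concrete illustration of what grad_f must return, here is a minimal sketch (not this crate's API) using f64 in place of Tensor, for the scalar dynamics f(t, y) = theta * y, where df/dy = theta and df/dtheta = y:

```rust
// Hypothetical scalar stand-in for `grad_f`: given the adjoint `a`,
// return (a^T * df/dy, a^T * df/dtheta) for f(t, y) = theta * y.
fn grad_f_scalar(theta: f64, _t: f64, y: f64, adjoint: f64) -> (f64, f64) {
    let vjp_y = adjoint * theta; // a^T * (df/dy), since df/dy = theta
    let vjp_theta = adjoint * y; // a^T * (df/dtheta), since df/dtheta = y
    (vjp_y, vjp_theta)
}

fn main() {
    // theta = 2, y = 3, adjoint = 0.5 => vjp_y = 1.0, vjp_theta = 1.5
    let (vjp_y, vjp_theta) = grad_f_scalar(2.0, 0.0, 3.0, 0.5);
    println!("vjp_y = {vjp_y}, vjp_theta = {vjp_theta}");
}
```

In the tensor case the same contract holds shape-wise: vjp_y has the shape of y, and vjp_theta has the shape of the flattened parameters.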

This implementation uses RK4 backward integration for reproducibility.
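The backward RK4 scheme can be sketched as follows, again with f64 standing in for Tensor (an assumption for illustration, not the crate's implementation). The augmented state (y, a) is stepped from T down to t0 with a negative step size; the example uses dy/dt = -y, for which the exact answers are y(t0) = 1 and a(t0) = a(T) * e^(t0 - T):

```rust
// Sketch of backward RK4 over the augmented state (y, a).
// `f` is the forward dynamics; `vjp_y` returns a^T * (df/dy).
fn adjoint_solve_scalar(
    f: impl Fn(f64, f64) -> f64,
    vjp_y: impl Fn(f64, f64, f64) -> f64,
    y_final: f64,
    adjoint_final: f64,
    t_span: (f64, f64),
    n_steps: usize,
) -> (f64, f64) {
    let (t0, t1) = t_span;
    let h = (t0 - t1) / n_steps as f64; // negative: integrate T -> t0
    let (mut y, mut a) = (y_final, adjoint_final);
    let mut t = t1;
    // Augmented dynamics: dy/dt = f(t, y), da/dt = -a^T * (df/dy)
    let g = |t: f64, y: f64, a: f64| (f(t, y), -vjp_y(t, y, a));
    for _ in 0..n_steps {
        let (k1y, k1a) = g(t, y, a);
        let (k2y, k2a) = g(t + 0.5 * h, y + 0.5 * h * k1y, a + 0.5 * h * k1a);
        let (k3y, k3a) = g(t + 0.5 * h, y + 0.5 * h * k2y, a + 0.5 * h * k2a);
        let (k4y, k4a) = g(t + h, y + h * k3y, a + h * k3a);
        y += h / 6.0 * (k1y + 2.0 * k2y + 2.0 * k3y + k4y);
        a += h / 6.0 * (k1a + 2.0 * k2a + 2.0 * k3a + k4a);
        t += h;
    }
    (y, a)
}

fn main() {
    // dy/dt = -y with y(1) = e^-1; backward integration recovers y(0) = 1.
    // df/dy = -1 => da/dt = a, so a(0) = a(1) * e^-1.
    let (y0, a0) = adjoint_solve_scalar(
        |_t, y| -y,
        |_t, _y, a| -a, // a^T * (df/dy) with df/dy = -1
        (-1.0f64).exp(),
        1.0,
        (0.0, 1.0),
        100,
    );
    println!("y0 = {y0:.6}, a0 = {a0:.6}");
}
```

Note that only the current augmented state is stored at any time, which is where the O(1) memory claim comes from: no forward trajectory is checkpointed.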

§Arguments

  • f - Forward dynamics: f(t, y) → dy/dt
  • grad_f - Returns (vjp_y, vjp_theta): vector-Jacobian products of the adjoint with the Jacobians of f. Signature: grad_f(t, y, adjoint) → (a^T * df/dy, a^T * df/dtheta)
  • y_final - State at final time T
  • t_span - (t0, T) — integrates BACKWARD from T to t0
  • n_steps - Number of backward integration steps

§Returns

(y0_reconstructed, adjoint_at_t0)
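In practice the parameter gradient is accumulated alongside the backward pass. A self-contained sketch (f64 in place of Tensor, values chosen for illustration) for dy/dt = theta * y with loss L = y(T): the accumulator G satisfies dG/dt = -a * (df/dtheta) going backward, and G(t0) equals dL/dtheta, which analytically is T * y0 * e^(theta * T):

```rust
// Backward RK4 over (y, a, G): state, adjoint, and parameter-gradient
// accumulator, for f(t, y) = theta * y with loss L = y(T).
fn adjoint_grad(theta: f64, t0: f64, t1: f64, n_steps: usize) -> (f64, f64) {
    let h = (t0 - t1) / n_steps as f64; // negative: integrate T -> t0
    // Backward initial conditions: y(T) = e^{theta*T} (y0 = 1),
    // a(T) = dL/dy(T) = 1, G(T) = 0.
    let (mut y, mut a, mut grad) = ((theta * t1).exp(), 1.0f64, 0.0f64);
    // dy/dt = theta*y, da/dt = -a*(df/dy) = -theta*a, dG/dt = -a*(df/dtheta) = -a*y
    let g = |y: f64, a: f64| (theta * y, -theta * a, -a * y);
    for _ in 0..n_steps {
        let (k1y, k1a, k1g) = g(y, a);
        let (k2y, k2a, k2g) = g(y + 0.5 * h * k1y, a + 0.5 * h * k1a);
        let (k3y, k3a, k3g) = g(y + 0.5 * h * k2y, a + 0.5 * h * k2a);
        let (k4y, k4a, k4g) = g(y + h * k3y, a + h * k3a);
        y += h / 6.0 * (k1y + 2.0 * k2y + 2.0 * k3y + k4y);
        a += h / 6.0 * (k1a + 2.0 * k2a + 2.0 * k3a + k4a);
        grad += h / 6.0 * (k1g + 2.0 * k2g + 2.0 * k3g + k4g);
    }
    (y, grad)
}

fn main() {
    // theta = 0.5, T = 1: expect y0 ~ 1 and dL/dtheta ~ e^0.5 ~ 1.6487
    let (y0, grad) = adjoint_grad(0.5, 0.0, 1.0, 100);
    println!("y0 = {y0:.6}, dL/dtheta = {grad:.6}");
}
```

The same structure applies with adjoint_solve: vjp_theta from grad_f is what gets quadrature-accumulated into the parameter gradient during the backward sweep.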