pub struct AdjointMethod {
    pub augmented_state: Vec<f64>,
    pub state_dim: usize,
}
Reverse-mode gradient of a Neural ODE loss via the continuous adjoint method.
The adjoint state a(t) = -dL/dz(t) is integrated backward in time. This
gives parameter gradients without storing the full forward trajectory.
Fields
augmented_state: Vec<f64>
Augmented state [z; a; dL/dθ] during backward integration.
state_dim: usize
Dimensionality of the ODE state.
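The [z; a; dL/dθ] layout can be illustrated with a small sketch. The helper names pack and unpack and the AugmentedView type are hypothetical and are not part of this crate. The sketch assumes that z and a each occupy state_dim entries and that any remaining entries hold the parameter-gradient block.

```rust
/// Borrowed views into an augmented state laid out as [z; a; dL/dθ].
/// Hypothetical type, used only for this illustration.
struct AugmentedView<'s> {
    z: &'s [f64],
    a: &'s [f64],
    grad_theta: &'s [f64],
}

/// Pack the three blocks into one flat vector (hypothetical helper).
fn pack(z: &[f64], a: &[f64], grad_theta: &[f64]) -> Vec<f64> {
    assert_eq!(z.len(), a.len(), "z and a must share the state dimension");
    let mut out = Vec::with_capacity(z.len() + a.len() + grad_theta.len());
    out.extend_from_slice(z);
    out.extend_from_slice(a);
    out.extend_from_slice(grad_theta);
    out
}

/// Split a flat augmented vector back into its blocks (hypothetical helper).
/// Everything after the first 2 * state_dim entries is treated as dL/dθ.
fn unpack(augmented: &[f64], state_dim: usize) -> AugmentedView<'_> {
    assert!(augmented.len() >= 2 * state_dim, "augmented state too short");
    let (z, rest) = augmented.split_at(state_dim);
    let (a, grad_theta) = rest.split_at(state_dim);
    AugmentedView { z, a, grad_theta }
}

fn main() {
    // Two state entries, two adjoint entries, three parameter-gradient entries.
    let aug = pack(&[1.0, 2.0], &[0.5, -0.5], &[0.0, 0.0, 0.0]);
    let view = unpack(&aug, 2);
    println!("z = {:?}, a = {:?}, dL/dθ = {:?}", view.z, view.a, view.grad_theta);
}
```

Under this assumed layout the parameter count is implied by the vector length: augmented.len() - 2 * state_dim.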
Implementations
impl AdjointMethod
pub fn new(state_dim: usize) -> Self
Construct a new AdjointMethod for an ODE with state dimension state_dim.
pub fn backward(&self, loss_grad: &[f64]) -> Vec<f64>
Compute parameter gradients given the loss gradient at the final time.
This is a simplified adjoint implementation: it propagates loss_grad
backward through one RK4 step and returns the approximate gradient with
respect to the initial state.
A full implementation would call func to integrate the adjoint ODE
backward in time; this version uses a finite-difference approximation to
illustrate the interface.
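As a rough illustration of how a finite-difference approximation can stand in for the Jacobian product that the adjoint ODE needs, the following sketch approximates a^T ∂f/∂z with central differences. The function vjp_central_diff, the closure f, and the step size eps are assumptions made for this example. The actual body of backward is not shown on this page and may differ.

```rust
/// Approximate the vector-Jacobian product a^T (∂f/∂z) by central differences.
/// Hypothetical helper: costs two evaluations of `f` per state coordinate.
fn vjp_central_diff<F>(f: F, z: &[f64], a: &[f64], eps: f64) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let mut probe = z.to_vec();
    let mut out = Vec::with_capacity(z.len());
    for j in 0..z.len() {
        let original = probe[j];
        probe[j] = original + eps;
        let f_plus = f(&probe);
        probe[j] = original - eps;
        let f_minus = f(&probe);
        probe[j] = original; // restore before perturbing the next coordinate

        // Column j of the Jacobian, contracted with a.
        let dot: f64 = a
            .iter()
            .zip(f_plus.iter().zip(f_minus.iter()))
            .map(|(ai, (fp, fm))| ai * (fp - fm) / (2.0 * eps))
            .sum();
        out.push(dot);
    }
    out
}

fn main() {
    // Toy dynamics f(z) = [z0 * z1, z0 + z1], chosen only for illustration.
    let f = |z: &[f64]| vec![z[0] * z[1], z[0] + z[1]];
    let v = vjp_central_diff(f, &[3.0, 4.0], &[1.0, 2.0], 1e-5);
    println!("{:?}", v); // close to [6.0, 5.0]
}
```

Finite differences need no derivative code but scale linearly with the state dimension, which is one reason a full adjoint implementation would normally use an analytic or automatic vector-Jacobian product instead.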
pub fn run(
    &mut self,
    solver: &NeuralOdeSolver,
    z_final: &[f64],
    loss_grad: &[f64],
    t0: f64,
    t1: f64,
    dt: f64,
) -> (Vec<f64>, Vec<f64>)
Set the final adjoint state from loss_grad and propagate it backward
through solver from t1 to t0 using RK4.
Returns (grad_z0, grad_params) where grad_z0 is the gradient with
respect to the initial state and grad_params is a flat vector of
approximate parameter gradients.
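The backward pass described above can be sketched on a scalar ODE dz/dt = θ·z with loss L = z(t1), where the exact gradients are known. This is a minimal, self-contained illustration and does not use NeuralOdeSolver. The names augmented_rhs, rk4_step, and adjoint_backward are hypothetical. The sketch tracks g(t) = dL/dz(t) directly, so under the sign convention a(t) = -dL/dz(t) the adjoint is a(t) = -g(t).

```rust
/// Right-hand side of the augmented system for dz/dt = theta * z.
/// State layout: [z, g, g_theta], with g = dL/dz.
fn augmented_rhs(s: &[f64; 3], theta: f64) -> [f64; 3] {
    let (z, g) = (s[0], s[1]);
    [
        theta * z,  // dz/dt       = f(z, theta)
        -g * theta, // dg/dt       = -g * df/dz
        -g * z,     // dg_theta/dt = -g * df/dtheta
    ]
}

/// Return x + c * k, component-wise.
fn axpy(x: &[f64; 3], k: &[f64; 3], c: f64) -> [f64; 3] {
    [x[0] + c * k[0], x[1] + c * k[1], x[2] + c * k[2]]
}

/// One classical RK4 step of size h (h < 0 steps backward in time).
fn rk4_step(s: &[f64; 3], theta: f64, h: f64) -> [f64; 3] {
    let k1 = augmented_rhs(s, theta);
    let k2 = augmented_rhs(&axpy(s, &k1, 0.5 * h), theta);
    let k3 = augmented_rhs(&axpy(s, &k2, 0.5 * h), theta);
    let k4 = augmented_rhs(&axpy(s, &k3, h), theta);
    let mut out = *s;
    for i in 0..3 {
        out[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
    }
    out
}

/// Integrate the augmented state from t1 back to t0.
/// Returns (grad_z0, grad_theta), mirroring the (grad_z0, grad_params) pair.
fn adjoint_backward(
    z_final: f64,
    loss_grad: f64,
    theta: f64,
    t0: f64,
    t1: f64,
    dt: f64,
) -> (f64, f64) {
    let steps = ((t1 - t0) / dt).round().max(1.0) as usize;
    let h = -(t1 - t0) / steps as f64; // negative step: backward in time
    let mut s = [z_final, loss_grad, 0.0]; // parameter gradient starts at zero
    for _ in 0..steps {
        s = rk4_step(&s, theta, h);
    }
    (s[1], s[2])
}

fn main() {
    let (theta, z0, t0, t1) = (0.5_f64, 2.0_f64, 0.0_f64, 1.0_f64);
    let z_final = z0 * (theta * (t1 - t0)).exp(); // exact forward solution
    let (grad_z0, grad_theta) = adjoint_backward(z_final, 1.0, theta, t0, t1, 0.01);
    // Analytic values for L = z(t1), with T = t1 - t0:
    // dL/dz0 = exp(theta * T), dL/dtheta = z0 * T * exp(theta * T).
    println!("grad_z0 = {grad_z0}, grad_theta = {grad_theta}");
}
```

Because z is re-integrated backward alongside the adjoint, only z_final is needed as input, which is what lets the method avoid storing the forward trajectory.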
Trait Implementations
impl Clone for AdjointMethod
fn clone(&self) -> AdjointMethod
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations
impl Freeze for AdjointMethod
impl RefUnwindSafe for AdjointMethod
impl Send for AdjointMethod
impl Sync for AdjointMethod
impl Unpin for AdjointMethod
impl UnsafeUnpin for AdjointMethod
impl UnwindSafe for AdjointMethod
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
impl<SS, SP> SupersetOf<SS> for SP
where
    SS: SubsetOf<SP>,
fn to_subset(&self) -> Option<SS>
The inverse inclusion map: attempts to construct self from the equivalent element of its superset.
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).
fn to_subset_unchecked(&self) -> SS
Use with care! Same as self.to_subset but without any property checks. Always succeeds.
fn from_subset(element: &SS) -> SP
The inclusion map: converts self to the equivalent element of its superset.