pub struct ActivationCheckpointer { /* private fields */ }
Manages activation checkpoints for a multi-layer forward pass.
Stores activations only at designated layer boundaries. During backpropagation, the discarded intermediate activations are recomputed from the nearest checkpoint instead of being kept in memory.
Implementations
impl ActivationCheckpointer
pub fn new(config: CheckpointConfig) -> Self
Create a new checkpointer with the given configuration.
pub fn save(
    &mut self,
    layer_idx: usize,
    activation: Array1<f32>,
) -> ModelResult<()>
Save an activation at the given layer index.
If use_mixed_precision is enabled, the activation is quantised
before storage.
Errors
Returns an error if max_checkpoints would be exceeded and the
layer is not a checkpoint boundary.
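The following is a minimal, dependency-free sketch of the two save rules above. It is not the crate's implementation: Vec<f32> stands in for Array1<f32>, and SketchStore, SaveError, checkpoint_every and quantise_bf16 are hypothetical names. The documentation does not say which format use_mixed_precision quantises to, so the sketch assumes bfloat16 rounding. For simplicity it keeps the rounded values in f32 rather than packing them into 16 bits.

```rust
use std::collections::BTreeMap;

/// Round an f32 to bfloat16 precision (keep the upper 16 bits, using
/// round-to-nearest-even) and widen it back to f32.
/// Hypothetical: the crate's actual quantisation scheme is not documented.
fn quantise_bf16(x: f32) -> f32 {
    if x.is_nan() {
        return x; // leave NaNs alone so rounding cannot turn them into infinities
    }
    let bits = x.to_bits();
    let bias = 0x7FFF + ((bits >> 16) & 1);
    f32::from_bits(bits.wrapping_add(bias) & 0xFFFF_0000)
}

/// Hypothetical error type standing in for the crate's ModelResult error.
#[derive(Debug, PartialEq)]
enum SaveError {
    TooManyCheckpoints { layer_idx: usize, max: usize },
}

/// Hypothetical simplified store; Vec<f32> stands in for Array1<f32>.
struct SketchStore {
    checkpoint_every: usize, // assumed boundary rule: every k-th layer (must be > 0)
    max_checkpoints: usize,
    use_mixed_precision: bool,
    slots: BTreeMap<usize, Vec<f32>>,
}

impl SketchStore {
    fn save(&mut self, layer_idx: usize, activation: Vec<f32>) -> Result<(), SaveError> {
        let is_boundary = layer_idx % self.checkpoint_every == 0;
        // Overwriting an existing slot never increases the checkpoint count.
        let would_exceed =
            !self.slots.contains_key(&layer_idx) && self.slots.len() >= self.max_checkpoints;
        // Mirrors the documented rule: only non-boundary layers are refused.
        if would_exceed && !is_boundary {
            return Err(SaveError::TooManyCheckpoints { layer_idx, max: self.max_checkpoints });
        }
        let stored = if self.use_mixed_precision {
            activation.into_iter().map(quantise_bf16).collect()
        } else {
            activation
        };
        self.slots.insert(layer_idx, stored);
        Ok(())
    }
}
```

The error branch follows the documented wording literally: a boundary layer is still accepted when the limit has been reached.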
pub fn get(&self, layer_idx: usize) -> ModelResult<&Array1<f32>>
Retrieve the checkpointed activation at the given layer.
Errors
Returns an error if no checkpoint exists for layer_idx.
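Retrieval can be sketched as a lookup that turns an empty or out-of-range slot into an error. The Vec<Option<...>> layout is an assumption suggested by the "non-None" wording under num_checkpoints. Slots and MissingCheckpoint are hypothetical names, and a slice stands in for &Array1<f32>.

```rust
use std::fmt;

/// Hypothetical error type standing in for the crate's ModelResult error.
#[derive(Debug, PartialEq)]
struct MissingCheckpoint(usize);

impl fmt::Display for MissingCheckpoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "no checkpoint stored for layer {}", self.0)
    }
}

/// Hypothetical storage layout: one optional slot per layer index.
struct Slots {
    slots: Vec<Option<Vec<f32>>>,
}

impl Slots {
    fn get(&self, layer_idx: usize) -> Result<&[f32], MissingCheckpoint> {
        // `Vec::get` handles out-of-range indices; `as_deref` turns
        // &Option<Vec<f32>> into Option<&[f32]>.
        self.slots
            .get(layer_idx)
            .and_then(|slot| slot.as_deref())
            .ok_or(MissingCheckpoint(layer_idx))
    }

    /// Counts only the occupied (non-None) slots.
    fn num_checkpoints(&self) -> usize {
        self.slots.iter().filter(|slot| slot.is_some()).count()
    }
}
```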
pub fn memory_saved_bytes(&self) -> usize
Estimated bytes of memory saved by not storing non-checkpointed activations.
This value is updated during checkpointed_forward calls.
pub fn memory_stored_bytes(&self) -> usize
Bytes currently stored in checkpoints.
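Both counters reduce to element count times element size. The sketch below assumes activations are accounted at full f32 precision (4 bytes per element). MemoryStats and record are hypothetical names, not part of this API.

```rust
use std::mem::size_of;

/// Bytes occupied by the elements of an f32 activation of length `len`
/// (ignores the small fixed header of the container itself).
fn activation_bytes(len: usize) -> usize {
    len * size_of::<f32>()
}

/// Hypothetical counters mirroring memory_stored_bytes / memory_saved_bytes.
#[derive(Default, Debug)]
struct MemoryStats {
    stored_bytes: usize,
    saved_bytes: usize,
}

impl MemoryStats {
    /// Record one layer's output: kept if it is a checkpoint, otherwise
    /// counted as memory saved by discarding it.
    fn record(&mut self, len: usize, is_checkpoint: bool) {
        let bytes = activation_bytes(len);
        if is_checkpoint {
            self.stored_bytes += bytes;
        } else {
            self.saved_bytes += bytes;
        }
    }
}
```

For example, recording eight layers of width 1024 while checkpointing every fourth layer (layers 0 and 4) gives stored_bytes = 2 * 4096 = 8192 and saved_bytes = 6 * 4096 = 24576. The documentation does not say whether the real counters use the quantised size when use_mixed_precision is enabled.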
pub fn num_checkpoints(&self) -> usize
Number of non-None checkpoints currently held.
pub fn is_checkpoint_layer(&self, layer_idx: usize) -> bool
Whether a given layer index is a checkpoint boundary according to the current configuration.
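The documentation does not list the fields of CheckpointConfig, so the boundary rule below is only an illustration. The hypothetical SketchConfig marks every k-th layer, plus an explicit list of extra layers, as a checkpoint boundary.

```rust
/// Hypothetical configuration; the real CheckpointConfig fields are not shown in these docs.
struct SketchConfig {
    /// Checkpoint every k-th layer index (0 disables the interval rule).
    checkpoint_every: usize,
    /// Extra layer indices that are always checkpoints.
    explicit_layers: Vec<usize>,
}

impl SketchConfig {
    fn is_checkpoint_layer(&self, layer_idx: usize) -> bool {
        // Guard against division by zero when the interval rule is disabled.
        let on_interval = self.checkpoint_every != 0 && layer_idx % self.checkpoint_every == 0;
        on_interval || self.explicit_layers.contains(&layer_idx)
    }
}
```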
pub fn checkpointed_forward<F>(
    &mut self,
    input: &Array1<f32>,
    layers: &[usize],
    forward_fn: F,
) -> ModelResult<Array1<f32>>
Run a checkpointed forward pass through the given layers.
The forward_fn is called sequentially for each layer in layers,
receiving the current activation and the layer index. Activations
at checkpoint boundaries are saved; the others are discarded, and
their size is added to the running total reported by memory_saved_bytes.
Parameters
input: the initial activation fed into the first layer.
layers: ordered list of layer indices to process.
forward_fn: a closure of type Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>> that applies one layer’s computation.
Returns
The activation after all layers have been applied.
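The loop described above can be sketched as follows. The sketch assumes that the activation "at" a layer means that layer's output, and that boundaries fall on every k-th layer index. ForwardSketch, checkpoint_every and the String error type are hypothetical, and Vec<f32> and &[f32] stand in for Array1<f32>.

```rust
use std::collections::BTreeMap;
use std::mem::size_of;

/// Hypothetical simplified checkpointer used only for this sketch.
struct ForwardSketch {
    checkpoint_every: usize,
    checkpoints: BTreeMap<usize, Vec<f32>>,
    bytes_saved: usize,
}

impl ForwardSketch {
    fn new(checkpoint_every: usize) -> Self {
        assert!(checkpoint_every > 0, "checkpoint_every must be positive");
        Self { checkpoint_every, checkpoints: BTreeMap::new(), bytes_saved: 0 }
    }

    fn is_checkpoint_layer(&self, layer_idx: usize) -> bool {
        layer_idx % self.checkpoint_every == 0
    }

    fn checkpointed_forward<F>(
        &mut self,
        input: &[f32],
        layers: &[usize],
        forward_fn: F,
    ) -> Result<Vec<f32>, String>
    where
        F: Fn(&[f32], usize) -> Result<Vec<f32>, String>,
    {
        let mut current = input.to_vec();
        for &layer in layers {
            // Apply one layer; `?` stops the pass on the first error.
            let output = forward_fn(&current, layer)?;
            if self.is_checkpoint_layer(layer) {
                // Keep a copy of this layer's output for the backward pass.
                self.checkpoints.insert(layer, output.clone());
            } else {
                // Not kept: count its size as memory saved.
                self.bytes_saved += output.len() * size_of::<f32>();
            }
            // The previous activation is dropped here.
            current = output;
        }
        Ok(current)
    }
}
```

Cloning at boundaries keeps a stored copy while ownership of the activation still moves on to the next layer.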
pub fn recompute_from_checkpoint<F>(
    &self,
    target_layer: usize,
    layers: &[usize],
    forward_fn: F,
) -> ModelResult<Array1<f32>>
Recompute activations from the nearest checkpoint up to target_layer.
This is used during the backward pass: find the closest checkpoint
before target_layer, then replay the forward function from there.
Parameters
target_layer: the layer whose activation is needed.
layers: the full ordered list of layer indices.
forward_fn: the same forward function used during the forward pass.
Returns
The recomputed activation at target_layer.
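Recomputation can be sketched as a backwards search through layers followed by a replay. The sketch assumes that checkpoints hold each layer's output. It searches at or before target_layer, so a checkpointed target is returned directly, and it reports an error when no earlier checkpoint exists. The free function, its checkpoints map argument and the String error type are hypothetical.

```rust
use std::collections::BTreeMap;

/// Hypothetical sketch: `checkpoints` maps a layer index to that layer's
/// saved output (Vec<f32> stands in for Array1<f32>).
fn recompute_from_checkpoint<F>(
    checkpoints: &BTreeMap<usize, Vec<f32>>,
    target_layer: usize,
    layers: &[usize],
    forward_fn: F,
) -> Result<Vec<f32>, String>
where
    F: Fn(&[f32], usize) -> Result<Vec<f32>, String>,
{
    let target_pos = layers
        .iter()
        .position(|&l| l == target_layer)
        .ok_or_else(|| format!("layer {target_layer} is not in `layers`"))?;
    // Walk backwards through the layer order to the nearest stored checkpoint.
    let start_pos = (0..=target_pos)
        .rev()
        .find(|&p| checkpoints.contains_key(&layers[p]))
        .ok_or_else(|| format!("no checkpoint at or before layer {target_layer}"))?;
    let mut current = checkpoints[&layers[start_pos]].clone();
    // Replay the layers after the checkpoint, up to and including the target.
    for &layer in &layers[start_pos + 1..=target_pos] {
        current = forward_fn(&current, layer)?;
    }
    Ok(current)
}
```

Searching by position in layers rather than by numeric layer index keeps the sketch correct when the layer indices are not contiguous.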
pub fn config(&self) -> &CheckpointConfig
Returns a reference to the checkpointer's configuration.
Trait Implementations
impl Clone for ActivationCheckpointer
fn clone(&self) -> ActivationCheckpointer
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations
impl Freeze for ActivationCheckpointer
impl RefUnwindSafe for ActivationCheckpointer
impl Send for ActivationCheckpointer
impl Sync for ActivationCheckpointer
impl Unpin for ActivationCheckpointer
impl UnsafeUnpin for ActivationCheckpointer
impl UnwindSafe for ActivationCheckpointer
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.