pub struct GradStore<R: Runtime> { /* private fields */ }
Storage for gradients computed during the backward pass.
Gradients are stored by tensor ID and accumulated when a tensor is used multiple times in the computation graph.
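As a worked illustration of why contributions must be summed rather than overwritten, consider z = x * x, where one multiply node reads the same tensor x on both inputs. The sketch below is self-contained and does not use this crate: tensors are hypothetical scalars (f64), tensor IDs are plain u32 values, and a std HashMap stands in for GradStore's private storage.

```rust
use std::collections::HashMap;

// Hypothetical stand-in: tensor IDs are plain integers and tensors are scalars.
type TensorId = u32;

/// Backward pass for z = x * x, where the multiply node reads tensor `x` as both
/// its left and right input. Returns the gradients keyed by tensor ID.
fn backward_square(x_id: TensorId, x: f64, grad_out: f64) -> HashMap<TensorId, f64> {
    // For z = lhs * rhs: dz/dlhs = rhs and dz/drhs = lhs. Here lhs == rhs == x,
    // so both contributions are grad_out * x and both target the same ID.
    let contributions = [(x_id, grad_out * x), (x_id, grad_out * x)];

    let mut grads = HashMap::new();
    for (id, g) in contributions {
        // Sum into the existing entry instead of overwriting it.
        *grads.entry(id).or_insert(0.0) += g;
    }
    grads
}

fn main() {
    let x_id: TensorId = 0;
    let grads = backward_square(x_id, 3.0, 1.0);
    // d(x^2)/dx = 2x, i.e. 6.0 at x = 3.0.
    println!("{:?}", grads.get(&x_id));
}
```

With overwriting semantics the entry for x would end up as 3.0 (a single contribution) instead of the correct 2x = 6.0.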
Implementations
impl<R: Runtime> GradStore<R>
pub fn insert(&mut self, id: TensorId, grad: Tensor<R>)
Insert a gradient, overwriting any gradient already stored for the same tensor ID.
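A minimal sketch of these overwrite semantics, assuming (hypothetically) that the private fields are a HashMap from tensor ID to gradient. TensorId and the Vec<f64> gradient type are stand-ins for the crate's TensorId and Tensor<R>, and the get accessor is hypothetical; it is not part of the documented API.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for TensorId and Tensor<R>.
type TensorId = u32;
type Grad = Vec<f64>;

#[derive(Default)]
struct GradStore {
    // Assumed storage: one gradient per tensor ID.
    grads: HashMap<TensorId, Grad>,
}

impl GradStore {
    /// Overwrite semantics: a gradient already stored under `id` is replaced, not summed.
    fn insert(&mut self, id: TensorId, grad: Grad) {
        self.grads.insert(id, grad);
    }

    /// Hypothetical accessor, used here only to inspect the result.
    fn get(&self, id: TensorId) -> Option<&Grad> {
        self.grads.get(&id)
    }
}

fn main() {
    let mut store = GradStore::default();
    store.insert(1, vec![1.0, 1.0]);
    store.insert(1, vec![5.0, 5.0]); // replaces the previous gradient
    println!("{:?}", store.get(1));
}
```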
pub fn accumulate<F>(&mut self, id: TensorId, grad: Tensor<R>, add_fn: F)
Accumulate a gradient for a tensor
If no gradient exists for this tensor, stores the gradient. If a gradient already exists, adds the new gradient to the existing one.
Arguments
- id - The tensor ID to accumulate the gradient for
- grad - The gradient tensor to accumulate
- add_fn - Function that adds two tensors: fn(existing, new) -> sum
This is needed when a tensor is used multiple times in the computation graph: by the chain rule, its total gradient is the sum of the gradients flowing back from each use.
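The accumulate contract can be sketched as follows. This is not the crate's implementation: TensorId and Grad (a Vec<f64>) are hypothetical stand-ins for TensorId and Tensor<R>, the storage is assumed to be a HashMap, and the F: FnOnce(existing, new) -> sum bound is inferred from the argument list above, since the real bound is not shown here. The add helper plays the role of a backend addition op.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for TensorId and Tensor<R>.
type TensorId = u32;
type Grad = Vec<f64>;

#[derive(Default)]
struct GradStore {
    grads: HashMap<TensorId, Grad>,
}

impl GradStore {
    // Assumed bound: add_fn(existing, new) -> sum, as described in the arguments list.
    fn accumulate<F>(&mut self, id: TensorId, grad: Grad, add_fn: F)
    where
        F: FnOnce(Grad, Grad) -> Grad,
    {
        let combined = match self.grads.remove(&id) {
            // A gradient already exists: sum it with the new contribution.
            Some(existing) => add_fn(existing, grad),
            // First contribution for this tensor: store it unchanged.
            None => grad,
        };
        self.grads.insert(id, combined);
    }
}

/// Element-wise addition standing in for a backend add op
/// (assumes both gradients have the same length; zip would truncate otherwise).
fn add(existing: Grad, new: Grad) -> Grad {
    existing.iter().zip(&new).map(|(x, y)| x + y).collect()
}

fn main() {
    let id: TensorId = 1;
    let mut store = GradStore::default();
    store.accumulate(id, vec![1.0, 2.0], add); // stored as-is
    store.accumulate(id, vec![0.5, 0.5], add); // summed with the existing gradient
    println!("{:?}", store.grads.get(&id));
}
```

Removing the existing entry and passing it by value matches the fn(existing, new) -> sum shape and lets an add function reuse the existing buffer if it wants to.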
pub fn try_accumulate<F>(&mut self, id: TensorId, grad: Tensor<R>, add_fn: F) -> Result<()>
Accumulate a gradient with a fallible addition function
Like accumulate, but the addition function can fail and return a Result.
This is the preferred method for use in backward passes where tensor
operations may fail.
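A sketch of the fallible variant under the same kind of hypothetical stand-ins (a HashMap store, Vec<f64> gradients, u32 IDs). ShapeError and checked_add are illustrative names, not part of this crate. The closure here borrows both operands so that a failed addition leaves the stored gradient untouched; the docs above do not say how the real try_accumulate behaves on error, so treat that choice as an assumption.

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

// Hypothetical stand-ins for TensorId and Tensor<R>.
type TensorId = u32;
type Grad = Vec<f64>;

/// Illustrative error type for a failed addition.
#[derive(Debug, PartialEq)]
struct ShapeError {
    expected: usize,
    got: usize,
}

#[derive(Default)]
struct GradStore {
    grads: HashMap<TensorId, Grad>,
}

impl GradStore {
    fn try_accumulate<F>(&mut self, id: TensorId, grad: Grad, add_fn: F) -> Result<(), ShapeError>
    where
        F: FnOnce(&Grad, &Grad) -> Result<Grad, ShapeError>,
    {
        match self.grads.entry(id) {
            Entry::Occupied(mut slot) => {
                // `?` returns the error early; the stored gradient is only replaced on success.
                let sum = add_fn(slot.get(), &grad)?;
                slot.insert(sum);
            }
            Entry::Vacant(slot) => {
                // First contribution: no addition needed, so add_fn is never called.
                slot.insert(grad);
            }
        }
        Ok(())
    }
}

/// Element-wise addition that fails instead of silently truncating on a length mismatch.
fn checked_add(existing: &Grad, new: &Grad) -> Result<Grad, ShapeError> {
    if existing.len() != new.len() {
        return Err(ShapeError { expected: existing.len(), got: new.len() });
    }
    Ok(existing.iter().zip(new).map(|(x, y)| x + y).collect())
}

fn main() {
    let id: TensorId = 1;
    let mut store = GradStore::default();
    store.try_accumulate(id, vec![1.0, 2.0], checked_add).unwrap();
    store.try_accumulate(id, vec![0.5, 0.5], checked_add).unwrap();

    // A mismatched gradient is reported as an error; the stored sum is kept.
    let result = store.try_accumulate(id, vec![1.0], checked_add);
    println!("{:?} {:?}", result, store.grads.get(&id));
}
```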
pub fn accumulate_or_insert(&mut self, id: TensorId, grad: Tensor<R>)
Insert a gradient, overwriting any existing value
This is intentionally simpler than accumulate: it inserts the gradient directly, without any addition, so despite its name it does not sum with an existing gradient. Use it when you don’t have access to a TensorOps client, or when overwriting semantics are desired.
For proper gradient accumulation (adding to existing gradients), use accumulate with an add function instead.
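To make the difference concrete, the following self-contained sketch sends the same two contributions to one tensor through accumulate_or_insert and to another through accumulate. It uses hypothetical stand-ins (HashMap storage, Vec<f64> gradients, u32 IDs) and an assumed FnOnce(existing, new) -> sum bound; it is not the crate's code.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for TensorId and Tensor<R>.
type TensorId = u32;
type Grad = Vec<f64>;

#[derive(Default)]
struct GradStore {
    grads: HashMap<TensorId, Grad>,
}

impl GradStore {
    /// As documented: overwrites any existing gradient instead of adding to it.
    fn accumulate_or_insert(&mut self, id: TensorId, grad: Grad) {
        self.grads.insert(id, grad);
    }

    /// Sums with an existing gradient via add_fn(existing, new) (assumed bound).
    fn accumulate<F: FnOnce(Grad, Grad) -> Grad>(&mut self, id: TensorId, grad: Grad, add_fn: F) {
        let combined = match self.grads.remove(&id) {
            Some(existing) => add_fn(existing, grad),
            None => grad,
        };
        self.grads.insert(id, combined);
    }
}

/// Element-wise addition (assumes equal lengths).
fn add(existing: Grad, new: Grad) -> Grad {
    existing.iter().zip(&new).map(|(x, y)| x + y).collect()
}

fn main() {
    let (a, b): (TensorId, TensorId) = (1, 2);
    let mut store = GradStore::default();
    // Tensor a: only the last contribution survives.
    store.accumulate_or_insert(a, vec![1.0]);
    store.accumulate_or_insert(a, vec![2.0]);
    // Tensor b: the contributions are summed.
    store.accumulate(b, vec![1.0], add);
    store.accumulate(b, vec![2.0], add);
    println!("a: {:?}, b: {:?}", store.grads[&a], store.grads[&b]);
}
```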
Trait Implementations
Auto Trait Implementations
impl<R> Freeze for GradStore<R>
impl<R> RefUnwindSafe for GradStore<R>
impl<R> Send for GradStore<R>
impl<R> Sync for GradStore<R>
impl<R> Unpin for GradStore<R>
impl<R> UnsafeUnpin for GradStore<R>
impl<R> UnwindSafe for GradStore<R>
Blanket Implementations
impl<T> ArchivePointee for T
type ArchivedMetadata = ()
fn pointer_metadata(_: &<T as ArchivePointee>::ArchivedMetadata) -> <T as Pointee>::Metadata
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> LayoutRaw for T
fn layout_raw(_: <T as Pointee>::Metadata) -> Result<Layout, LayoutError>
impl<T, N1, N2> Niching<NichedOption<T, N1>> for N2
unsafe fn is_niched(niched: *const NichedOption<T, N1>) -> bool
fn resolve_niched(out: Place<NichedOption<T, N1>>)
Writes data to out indicating that a T is niched.