pub struct Gradients { /* private fields */ }
Expand description
A generic container for keeping variable-sized arrays associated with a UniqueId.
You can:
- Insert array values into it
- Remove entries
- Access references to arrays
- Access mutable references to arrays
This structure is similar to a HashMap, where all the methods require a key implementing `HasUniqueId` and `HasArrayType`.
Under the hood, it actually is a HashMap, and stores values as `Box<dyn Any>`.
Implementations
sourceimpl Gradients
impl Gradients
sourcepub fn mut_and_ref<L, R>(&mut self, l: &L, r: &R) -> (&mut L::Array, &R::Array) where
L: HasUniqueId + HasArrayType + HasDevice,
R: HasUniqueId + HasArrayType,
pub fn mut_and_ref<L, R>(&mut self, l: &L, r: &R) -> (&mut L::Array, &R::Array) where
L: HasUniqueId + HasArrayType + HasDevice,
R: HasUniqueId + HasArrayType,
Borrows a pair of gradients `(&mut L, &R)`. `l` is the gradient to update, and `r` is the gradient to backprop.
Panics if `l` and `r` have the same id.
Examples:
let a = Tensor1D::new([1.0, 2.0, 3.0]);
let b: Tensor1D<5> = Tensor1D::zeros();
let mut gradients: Gradients = Default::default();
*gradients.mut_gradient(&a) = [-4.0, 5.0, -6.0];
*gradients.mut_gradient(&b) = [1.0, 2.0, 3.0, 4.0, 5.0];
let (g_a, g_b) = gradients.mut_and_ref(&a, &b);
assert_eq!(g_a, &mut [-4.0, 5.0, -6.0]);
assert_eq!(g_b, &[1.0, 2.0, 3.0, 4.0, 5.0]);
pub fn muts_and_ref<L1, L2, L3, R>(
&mut self,
l1: &L1,
l2: &L2,
l3: &L3,
r: &R
) -> (&mut L1::Array, &mut L2::Array, &mut L3::Array, &R::Array) where
L1: HasUniqueId + HasArrayType + HasDevice,
L2: HasUniqueId + HasArrayType + HasDevice,
L3: HasUniqueId + HasArrayType + HasDevice,
R: HasUniqueId + HasArrayType,
sourcepub fn remove<T: HasUniqueId + HasArrayType>(
&mut self,
t: &T
) -> Option<Box<T::Array>>
pub fn remove<T: HasUniqueId + HasArrayType>(
&mut self,
t: &T
) -> Option<Box<T::Array>>
Removes and returns the data associated with `t.id()`.
Panics if data associated with `t` is not found. This indicates an unrecoverable bug.
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
*gradients.mut_gradient(&t) = [-4.0, 5.0, -6.0];
assert_eq!(gradients.remove(&t).expect("").as_ref(), &[-4.0, 5.0, -6.0]);
sourcepub fn mut_gradient<T: HasUniqueId + HasArrayType + HasDevice>(
&mut self,
t: &T
) -> &mut T::Array
pub fn mut_gradient<T: HasUniqueId + HasArrayType + HasDevice>(
&mut self,
t: &T
) -> &mut T::Array
Returns a mutable reference to the data associated with `t`.
If no data is associated with `t`, then `AllocateZeros::zeros` is called to allocate the data.
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
let g: &mut [f32; 3] = gradients.mut_gradient(&t);
assert_eq!(g, &mut [0.0, 0.0, 0.0]);
g[0] = 1.0;
assert_eq!(gradients.ref_gradient(&t), &[1.0, 0.0, 0.0]);
sourcepub fn ref_gradient<T: HasUniqueId + HasArrayType>(&self, t: &T) -> &T::Array
pub fn ref_gradient<T: HasUniqueId + HasArrayType>(&self, t: &T) -> &T::Array
Returns a reference to the data associated with `t`.
Panics
If no data is associated with `t` yet, this will panic due to an `unwrap()` on a `.get()` to the underlying hashmap.
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
gradients.mut_gradient(&t);
assert_eq!(gradients.ref_gradient(&t), &[0.0, 0.0, 0.0]);
Trait Implementations
Auto Trait Implementations
impl !RefUnwindSafe for Gradients
impl !Send for Gradients
impl !Sync for Gradients
impl Unpin for Gradients
impl !UnwindSafe for Gradients
Blanket Implementations
sourceimpl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
const: unstable · sourcefn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more