pub struct Gradients { /* private fields */ }
A generic container for variable-sized arrays, each associated with a UniqueId.
You can:
- Insert array values into it
- Remove entries
- Access references to arrays
- Access mutable references to arrays
This structure is similar to a HashMap, except that every method requires a key implementing HasUniqueId and HasArrayType.
Under the hood, it actually is a HashMap, and stores values as Box.
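To make the description above concrete, here is a minimal standalone sketch of a HashMap-backed container keyed by id. It is not the crate's implementation: the trait definitions, `Key3`, `SketchGradients`, and the `get` method are hypothetical stand-ins, and storing the values as `Box<dyn Any>` with a downcast to the key's `Array` type is an assumption about how the boxing might work.

```rust
use std::any::Any;
use std::collections::HashMap;

// Hypothetical stand-ins for the key traits; the crate's real definitions may differ.
trait HasUniqueId {
    fn id(&self) -> usize;
}

trait HasArrayType {
    type Array: 'static;
}

// Hypothetical key type whose associated array is [f32; 3].
struct Key3(usize);

impl HasUniqueId for Key3 {
    fn id(&self) -> usize {
        self.0
    }
}

impl HasArrayType for Key3 {
    type Array = [f32; 3];
}

// Sketch container: type-erased boxes keyed by id (assumed layout).
#[derive(Default)]
struct SketchGradients {
    map: HashMap<usize, Box<dyn Any>>,
}

impl SketchGradients {
    // Stores `data` under the key's id, erasing its concrete type.
    fn insert<T: HasUniqueId + HasArrayType>(&mut self, t: &T, data: Box<T::Array>) {
        self.map.insert(t.id(), data);
    }

    // Removes the entry and downcasts it back to the key's array type.
    fn remove<T: HasUniqueId + HasArrayType>(&mut self, t: &T) -> Option<Box<T::Array>> {
        self.map
            .remove(&t.id())
            .and_then(|b| b.downcast::<T::Array>().ok())
    }

    // Borrows the entry, if present, as the key's array type.
    fn get<T: HasUniqueId + HasArrayType>(&self, t: &T) -> Option<&T::Array> {
        self.map
            .get(&t.id())
            .and_then(|b| b.downcast_ref::<T::Array>())
    }
}

fn main() {
    let key = Key3(0);
    let mut grads = SketchGradients::default();
    grads.insert(&key, Box::new([-4.0f32, 5.0, -6.0]));
    assert_eq!(grads.get(&key), Some(&[-4.0f32, 5.0, -6.0]));
    assert_eq!(grads.remove(&key).as_deref(), Some(&[-4.0f32, 5.0, -6.0]));
    assert!(grads.get(&key).is_none());
}
```

In this sketch the associated `Array` type is what lets each method recover a concretely typed array from a type-erased box, which is why the key must implement both traits.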
Implementations
impl Gradients
pub fn insert<T: HasUniqueId + HasArrayType>(
&mut self,
t: &T,
data: Box<T::Array>
)
Inserts data associated with t.id().
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
gradients.insert(&t, Box::new([-4.0, 5.0, -6.0]));
assert_eq!(gradients.ref_gradient(&t), &[-4.0, 5.0, -6.0]);
pub fn remove<T: HasUniqueId + HasArrayType>(
&mut self,
t: &T
) -> Option<Box<T::Array>>
Removes and returns the data associated with t.id().
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
gradients.insert(&t, Box::new([-4.0, 5.0, -6.0]));
assert_eq!(gradients.remove(&t).unwrap().as_ref(), &[-4.0, 5.0, -6.0]);
pub fn mut_gradient<T: HasUniqueId + HasArrayType + HasDevice>(
&mut self,
t: &T
) -> &mut T::Array
Returns a mutable reference to the data associated with t.
If no data is associated with t, then AllocateZeros::zeros is called to allocate the data.
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
let g: &mut [f32; 3] = gradients.mut_gradient(&t);
assert_eq!(g, &mut [0.0, 0.0, 0.0]);
g[0] = 1.0;
assert_eq!(gradients.ref_gradient(&t), &[1.0, 0.0, 0.0]);
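The allocate-on-first-access behaviour can be illustrated with the standard library's entry API on a plain HashMap. This is only a sketch of the pattern: `mut_or_zeros` is a hypothetical helper, the fixed `[f32; 3]` value type is chosen for brevity, and whether the crate uses the entry API internally is an assumption.

```rust
use std::collections::HashMap;

// Hypothetical helper: returns the entry for `id`, zero-filling it on first access.
fn mut_or_zeros(map: &mut HashMap<usize, [f32; 3]>, id: usize) -> &mut [f32; 3] {
    map.entry(id).or_insert_with(|| [0.0; 3])
}

fn main() {
    let mut map = HashMap::new();

    // First access allocates zeros.
    let g = mut_or_zeros(&mut map, 0);
    assert_eq!(*g, [0.0, 0.0, 0.0]);
    g[0] = 1.0;

    // Later accesses see the existing data rather than fresh zeros.
    assert_eq!(*mut_or_zeros(&mut map, 0), [1.0, 0.0, 0.0]);
}
```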
pub fn ref_gradient<T: HasUniqueId + HasArrayType>(&self, t: &T) -> &T::Array
Returns a reference to the data associated with t.
Panics
If no data is associated with t yet, this panics because of an unwrap() on the result of .get() on the underlying HashMap.
Example usage:
let t = Tensor1D::new([1.0, 2.0, 3.0]);
let mut gradients: Gradients = Default::default();
gradients.mut_gradient(&t);
assert_eq!(gradients.ref_gradient(&t), &[0.0, 0.0, 0.0]);
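The panic described above is what any `get(..).unwrap()` on a HashMap does when the key is missing. The sketch below reproduces it on a plain HashMap: `ref_or_panic` is a hypothetical stand-in, not the crate's code, and `catch_unwind` is used only to observe the panic in a runnable example.

```rust
use std::collections::HashMap;
use std::panic;

// Hypothetical stand-in for a lookup that unwraps the result of `get`.
fn ref_or_panic(map: &HashMap<usize, [f32; 3]>, id: usize) -> &[f32; 3] {
    map.get(&id).unwrap()
}

fn main() {
    let mut map: HashMap<usize, [f32; 3]> = HashMap::new();

    // Missing key: the unwrap panics. catch_unwind turns that into an Err here.
    let missing = panic::catch_unwind(|| {
        ref_or_panic(&map, 0);
    });
    assert!(missing.is_err());

    // Once data exists for the key, the same lookup succeeds.
    map.insert(0, [0.0; 3]);
    assert_eq!(*ref_or_panic(&map, 0), [0.0, 0.0, 0.0]);
}
```

With the API shown on this page, calling mut_gradient for a key before ref_gradient avoids the panic, since mut_gradient allocates zeros when no data exists yet.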
Trait Implementations
Auto Trait Implementations
impl !RefUnwindSafe for Gradients
impl !Send for Gradients
impl !Sync for Gradients
impl Unpin for Gradients
impl !UnwindSafe for Gradients
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.