pub struct Array { /* private fields */ }
Expand description
An n-dimensional differentiable array. Stored in row-major order.
§Examples
use corgi::array::*;
let a = arr![arr![1.0, 2.0, 3.0], arr![4.0, 5.0, 6.0]];
let b = arr![arr![3.0, 2.0, 1.0], arr![6.0, 5.0, 4.0]];
let mut p = &a * &b;
p.backward(None);
Implementations§
Source§impl Array
impl Array
Source§impl Array
impl Array
Sourcepub fn reshape(&self, dimensions: Vec<usize>) -> Array
pub fn reshape(&self, dimensions: Vec<usize>) -> Array
Reshapes the array into different dimensions
Sourcepub fn axpy(alpha: Float, x: &Array, y: &Array) -> Array
pub fn axpy(alpha: Float, x: &Array, y: &Array) -> Array
Computes the element-wise `alpha * x + y`, for each matching dimension not multiplied.
Sourcepub fn matmul(a: (&Array, bool), b: (&Array, bool), c: Option<&Array>) -> Array
pub fn matmul(a: (&Array, bool), b: (&Array, bool), c: Option<&Array>) -> Array
Computes matrix multiplications on two arrays, for each matching dimension not multiplied.
§Arguments
- `a` - The LHS matrix, and whether to transpose it: `(a, a_transpose)`.
- `b` - The RHS matrix, and whether to transpose it: `(b, b_transpose)`.
- `c` - The output matrix, if not initialized to a zeros matrix.
Source§impl Array
impl Array
Sourcepub fn dimensions(&self) -> &[usize]
pub fn dimensions(&self) -> &[usize]
Returns a reference to the dimensions of the array.
Sourcepub fn values(&self) -> &[Float] ⓘ
pub fn values(&self) -> &[Float] ⓘ
Returns an immutable reference to the values of the array in row-major order.
Sourcepub fn gradient(&self) -> Ref<'_, Option<Array>>
pub fn gradient(&self) -> Ref<'_, Option<Array>>
Returns a reference to the gradient option of the array.
Sourcepub fn gradient_mut(&self) -> RefMut<'_, Option<Array>>
pub fn gradient_mut(&self) -> RefMut<'_, Option<Array>>
Returns a mutable reference to the gradient option of the array.
Sourcepub fn replace_gradient(&self) -> Option<Array>
pub fn replace_gradient(&self) -> Option<Array>
Returns the owned gradient option of the array, replacing it with nothing.
Sourcepub fn tracked(self) -> Array
pub fn tracked(self) -> Array
Enables tracking of operations for the backward pass, meaning the backward pass will compute and store gradients for the current array, and any children arrays which are tracked.
An operation with any positive number of tracked children will always output a tracked array.
This does not persist through threads, or through being set on a clone, meaning any tracked clones will not affect tracking of the array, apart from clones of the clone.
§Examples
// only the gradient for `b`, will be stored
let mut a = arr![1.0, 2.0, 3.0].untracked();
let b = arr![3.0, 2.0, 1.0].tracked();
let mut c = &a * &b;
c.backward(None);
assert_eq!(b.gradient().to_owned().unwrap(), arr![1.0, 2.0, 3.0]);
Sourcepub fn start_tracking(&self) -> bool
pub fn start_tracking(&self) -> bool
Starts tracking operations for a mutable reference to an array, returning the previous value.
Sourcepub fn untracked(self) -> Array
pub fn untracked(self) -> Array
Prevents tracking of operations for the backward pass, meaning the backward pass will skip computation of gradients for the current array, and any children arrays.
Any operation with every child untracked will always output an untracked array, and will not store any subgraph information.
This does not persist through threads, or through being set on a clone, meaning any untracked clones will not affect tracking of the array, apart from clones of the clone.
§Examples
// only the gradient for `b`, will be stored
let mut a = arr![1.0, 2.0, 3.0].untracked();
let b = arr![3.0, 2.0, 1.0].tracked();
let mut c = &a * &b;
c.backward(None);
assert_eq!(b.gradient().to_owned().unwrap(), arr![1.0, 2.0, 3.0]);
Sourcepub fn stop_tracking(&self) -> bool
pub fn stop_tracking(&self) -> bool
Stops tracking operations for a mutable reference to an array, returning the previous value. Useful for temporarily updating parameters without requiring their gradients.
Sourcepub fn backward(&self, delta: Option<Array>)
pub fn backward(&self, delta: Option<Array>)
Computes the backward pass, computing gradients for all descendants, and propagating consumer counts if requested.
§Panics
Panics if the current node has children, but is not a differentiable function (is not a leaf).
Sourcepub fn op(
arrays: &[&Array],
op: ForwardOp,
backward_op: Option<BackwardOp>,
) -> Array
pub fn op( arrays: &[&Array], op: ForwardOp, backward_op: Option<BackwardOp>, ) -> Array
Computes an operation on arrays.
§Arguments
- `arrays` - The arrays to perform the operations on.
- `op` - The `ForwardOp`, which takes in the arrays, and outputs another array.
- `backward_op` - The `BackwardOp`, which takes in the arrays, and the delta, and outputs a new delta, with respect to each input. It is recommended that any array operations here are untracked, unless interested in higher order derivatives.
§Examples
let mul: ForwardOp = Rc::new(|x: &[&Array]| {
Array::from((x[0].dimensions().to_vec(), x[0].values().iter().zip(x[1].values()).map(|(x, y)| x * y).collect::<Vec<Float>>()))
});
let mul_clone = Rc::clone(&mul);
let backward_op: BackwardOp = Rc::new(move |children, is_tracked, delta| {
vec![
if is_tracked[0] {
Some(Array::op(&[&children[1], delta], Rc::clone(&mul_clone), None))
} else {
None
},
if is_tracked[1] {
Some(Array::op(&[&children[0], delta], Rc::clone(&mul_clone), None))
} else {
None
}
]
});
let a = arr![1.0, 2.0, 3.0].tracked();
let b = arr![3.0, 2.0, 1.0].tracked();
let product = Array::op(&vec![&a, &b], mul, Some(backward_op));
assert_eq!(product, arr![3.0, 4.0, 3.0]);
product.backward(None);
assert_eq!(product.gradient().to_owned().unwrap(), arr![1.0, 1.0, 1.0]);
assert_eq!(b.gradient().to_owned().unwrap(), arr![1.0, 2.0, 3.0]);
assert_eq!(a.gradient().to_owned().unwrap(), arr![3.0, 2.0, 1.0]);
Trait Implementations§
Source§impl AbsDiffEq for Array
impl AbsDiffEq for Array
Source§fn default_epsilon() -> <Float as AbsDiffEq>::Epsilon
fn default_epsilon() -> <Float as AbsDiffEq>::Epsilon
Source§fn abs_diff_eq(
&self,
other: &Array,
epsilon: <Float as AbsDiffEq>::Epsilon,
) -> bool
fn abs_diff_eq( &self, other: &Array, epsilon: <Float as AbsDiffEq>::Epsilon, ) -> bool
Source§fn abs_diff_ne(&self, other: &Rhs, epsilon: Self::Epsilon) -> bool
fn abs_diff_ne(&self, other: &Rhs, epsilon: Self::Epsilon) -> bool
The inverse of `AbsDiffEq::abs_diff_eq`.
Source§impl From<(Vec<usize>, Vec<f64>)> for Array
Implementation to construct `Array` structs by using `Vec<usize>` as the dimensions, and `Vec<Float>` as the values.
impl From<(Vec<usize>, Vec<f64>)> for Array
Implementation to construct `Array` structs by using `Vec<usize>` as the dimensions, and `Vec<Float>` as the values.
§Examples
let a = Array::from((vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
assert_eq!(a[vec![1, 2]], 6.0);
Source§impl From<Vec<Array>> for Array
Implementation to construct `Array` structs by flattening other contained `Array` structs, and keeping their dimensions.
impl From<Vec<Array>> for Array
Implementation to construct `Array` structs by flattening other contained `Array` structs, and keeping their dimensions.
§Examples
let a = Array::from(vec![arr![1.0, 2.0, 3.0], arr![4.0, 5.0, 6.0]]);
assert_eq!(a[vec![1, 2]], 6.0);
Source§impl From<Vec<f64>> for Array
Implementation to construct `Array` structs by using `Vec<Float>` as the values, and by keeping flat dimensions.
impl From<Vec<f64>> for Array
Implementation to construct `Array` structs by using `Vec<Float>` as the values, and by keeping flat dimensions.
§Examples
let a = Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
assert_eq!(a[vec![5]], 6.0);
Source§impl From<Vec<usize>> for Array
Implementation to construct `Array` structs by using `Vec<usize>` as the dimensions, and filling values with zeros.
impl From<Vec<usize>> for Array
Implementation to construct `Array` structs by using `Vec<usize>` as the dimensions, and filling values with zeros.
§Examples
let a = Array::from(vec![3, 2, 3]);
assert_eq!(a[vec![2, 1, 1]], 0.0);
Source§impl RelativeEq for Array
impl RelativeEq for Array
Source§fn default_max_relative() -> <Float as AbsDiffEq>::Epsilon
fn default_max_relative() -> <Float as AbsDiffEq>::Epsilon
Source§fn relative_eq(
&self,
other: &Array,
epsilon: <Float as AbsDiffEq>::Epsilon,
max_relative: <Float as AbsDiffEq>::Epsilon,
) -> bool
fn relative_eq( &self, other: &Array, epsilon: <Float as AbsDiffEq>::Epsilon, max_relative: <Float as AbsDiffEq>::Epsilon, ) -> bool
Source§fn relative_ne(
&self,
other: &Rhs,
epsilon: Self::Epsilon,
max_relative: Self::Epsilon,
) -> bool
fn relative_ne( &self, other: &Rhs, epsilon: Self::Epsilon, max_relative: Self::Epsilon, ) -> bool
The inverse of `RelativeEq::relative_eq`.