Struct autograd::Graph

pub struct Graph<F: Float> { /* fields omitted */ }

Generator of Tensor objects.

Use autograd::with to instantiate this.

use autograd as ag;

ag::with(|graph1: &mut ag::Graph<f32>| {
    // Creating some nodes (tensors) in this graph.
    let a = graph1.zeros(&[2, 3]);
    let b = graph1.ones(&[2, 3]);

    // Evaluate the tensors
    (a + b).eval(&[]);

    // Creating another scope (graph).
    ag::with(|graph2: &mut ag::Graph<f32>| {
        // `c` is valid only in graph2.
        let c = graph2.zeros(&[3, 4]);

        // Cross-scope access to anything derived from another `Graph` does not compile for now.

        // graph1.zeros(&[2, 3])
        // ^^^^^^ invalid access to `graph1`

        // a + c
        // ^ invalid access to `a`, which belongs to `graph1`
    });
    // Tensors in graph2 are dropped here.
});
// Tensors in graph1 are dropped here.

Implementations

Returns symbolic gradient tensors of xs, in the same order as xs.

Arguments
  • ys - Targets of differentiation; they may have arbitrary shapes.
  • xs - Tensors with respect to which ys are differentiated.
Example

Partial derivatives of z = 2x^2 + 3y + 1.

use ndarray;
use autograd as ag;

ag::with(|g| {
    let x = g.placeholder(&[]);
    let y = g.placeholder(&[]);
    let z = 2.*x*x + 3.*y + 1.;

    // dz/dy
    let gy = g.grad(&[z], &[y])[0];
    // dz/dx
    let gx = g.grad(&[z], &[x])[0];

    // d²z/dx² (differentiates `gx`, i.e. `z` a second time)
    let ggx = g.grad(&[gx], &[x])[0];

    // evaluation of symbolic gradients
    assert_eq!(3., gy.eval(&[]).unwrap()[ndarray::IxDyn(&[])]);
    assert_eq!(4., ggx.eval(&[]).unwrap()[ndarray::IxDyn(&[])]);

    // evaluating dz/dx requires feeding the placeholder `x`
    assert_eq!(8., gx.eval(&[x.given(ndarray::arr0(2.).view())]).unwrap()[ndarray::IxDyn(&[])]);
});
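
The values asserted above can be cross-checked without the library. The following is a minimal sketch in plain Rust that compares the hand-derived partial derivatives of z = 2x^2 + 3y + 1 against a central finite difference. The names `z`, `dz_dx`, `dz_dy` and `central_diff` are illustrative only and are not part of the crate.

```rust
/// z = 2x^2 + 3y + 1, the function from the example above.
fn z(x: f64, y: f64) -> f64 {
    2.0 * x * x + 3.0 * y + 1.0
}

/// Hand-derived partial derivative dz/dx = 4x.
fn dz_dx(x: f64) -> f64 {
    4.0 * x
}

/// Hand-derived partial derivative dz/dy = 3 (independent of x and y).
fn dz_dy() -> f64 {
    3.0
}

/// Central finite difference of a one-argument function at `at` with step `h`.
fn central_diff(f: impl Fn(f64) -> f64, at: f64, h: f64) -> f64 {
    (f(at + h) - f(at - h)) / (2.0 * h)
}
```

At x = 2, `dz_dx(2.0)` is 8, which matches the value asserted for `gx`, and `central_diff(|x| z(x, 1.0), 2.0, 1e-4)` agrees with it up to rounding error.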

Computes the gradients of xs given the already known gradients of ys.

Almost the same spec as grad, except that you can pass the already known gradients of ys. If ys_grads are tensors filled with 1s, use grad instead.

NOTE: Make sure that ys_grads[i].shape matches ys[i].shape; otherwise the behavior is undefined.

Arguments
  • ys - Targets of differentiation.
  • xs - Tensors with respect to which ys are differentiated.
  • ys_grads - Already known gradients of ys.
Returns

Symbolic gradient tensors of xs, in the same order as xs.

Computes Jacobians for variables.

Arguments
  • y - Target of differentiation.
  • xs - Tensors with respect to which y is differentiated.
  • y_size - (flattened) size of y
Returns

Jacobians for each variable. Each one is a matrix of shape (y_size, x_size), where x_size is the flattened size of the corresponding x.

Note: the current implementation works correctly but is unoptimized for serious use.

use autograd as ag;
use ag::tensor::Variable;

ag::with(|g| {
   let rng = ag::ndarray_ext::ArrayRng::<f32>::default();
   let a = g.variable(rng.standard_normal(&[4, 2]));
   let b = g.variable(rng.standard_normal(&[2, 3]));
   let c = g.matmul(a, b);
   let j = g.jacobians(c, &[a, b], 4*3);

   assert_eq!(j[0].eval(&[]).unwrap().shape(), &[4*3, 4*2]);
   assert_eq!(j[1].eval(&[]).unwrap().shape(), &[4*3, 2*3]);
});
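
To make the Jacobian layout concrete (one row per element of y, one column per element of x), here is a minimal sketch in plain Rust, independent of the crate, that builds the Jacobian of c = a · b with respect to a by hand. The function name `matmul_jacobian_wrt_a` and the row-major flattening of both c and a are assumptions made for illustration; the crate's own element ordering is not stated above.

```rust
/// Jacobian of c = a · b with respect to `a`, for a of shape (m, p) and
/// b of shape (p, n). Both c and a are flattened in row-major order
/// (an assumption of this sketch).
///
/// Entry [i * n + j][k * p + l] is d c[i][j] / d a[k][l],
/// which equals b[l][j] when i == k and 0 otherwise.
fn matmul_jacobian_wrt_a(m: usize, b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let p = b.len();
    let n = b[0].len();
    // (m * n) rows for the elements of c, (m * p) columns for the elements of a.
    let mut jac = vec![vec![0.0; m * p]; m * n];
    for i in 0..m {
        for j in 0..n {
            for l in 0..p {
                jac[i * n + j][i * p + l] = b[l][j];
            }
        }
    }
    jac
}
```

For a of shape (4, 2) and b of shape (2, 3), the result has 4*3 = 12 rows and 4*2 = 8 columns, the same shape asserted for `j[0]` in the example above.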

(Experimental) Computes the Hessian-vector product.

Stops gradient propagation.

Guarantees that the gradient is not propagated to the tensors behind this during gradient computation.

Creates a placeholder tensor.

Behaves like TensorFlow’s placeholder object. shape_[i] must be a positive value, or -1, which denotes a dynamic dimension.

use ndarray;
use autograd as ag;

ag::with(|g| {
    let x = g.placeholder(&[2]);

    // Fills placeholder, then eval
    let arr = ndarray::array![1., 1.].into_dyn();
    assert_eq!(x.eval(&[x.given(arr.view())]), Ok(arr));
});

Returns a Tensor representation of the input tensor’s shape

use autograd as ag;

ag::with(|g| {
   let x: ag::Tensor<f32> = g.zeros(&[2, 3]);
   let s = g.shape(x);
   assert_eq!(&[2., 3.], s.eval(&[]).unwrap().as_slice().unwrap());
});

Returns the (symbolic) size of the input tensor

use ndarray;
use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[4, 3]);
   let b = g.size(a);

   assert_eq!(12., b.eval(&[]).unwrap()[ndarray::IxDyn(&[])]);
});

Returns the (symbolic) rank of the input tensor

use ndarray;
use autograd as ag;

ag::with(|g| {
   let x: ag::Tensor<f32> = g.zeros(&[2, 3, 4]);
   let r = g.rank(x);
   assert_eq!(3., r.eval(&[]).unwrap()[ndarray::IxDyn(&[])]);
});

Elementwise sine

Elementwise cosine

Elementwise tangent

Elementwise arcsin

Elementwise arccos

Elementwise arctan

Elementwise hyperbolic sine

Elementwise hyperbolic cosine

Elementwise hyperbolic tangent

Elementwise hyperbolic arcsin

Elementwise hyperbolic arccos

Elementwise hyperbolic arctan

Identity function without copy.

Elementwise addition.

This can be replaced with + operation of Tensor.

Element-wise subtraction.

This can be replaced with - operation of Tensor.

Elementwise multiplication.

This can be replaced with * operation of Tensor.

Elementwise division.

This can be replaced with / operation of Tensor.

Elementwise sqrt

Elementwise pow

Elementwise natural (base e) logarithm

Elementwise base 2 logarithm

Elementwise base 10 logarithm

Elementwise natural (base e) exponential

Elementwise base 2 exponential

Elementwise base 10 exponential

Returns the max of x and y (i.e. x > y ? x : y) element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![1., 2., 3.]);
   let b = g.constant(array![3., 2., 1.]);
   let c = g.maximum(a, b);
   assert_eq!(c.eval(&[]), Ok(array![3., 2., 3.].into_dyn()));
});

Returns the min of x and y (i.e. x > y ? y : x) element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![1., 2., 3.]);
   let b = g.constant(array![3., 2., 1.]);
   let c = g.minimum(a, b);
   assert_eq!(c.eval(&[]), Ok(array![1., 2., 1.].into_dyn()));
});

Adds all input tensors, element-wise.

All the input tensors must have the same shape.

use ndarray::array;
use autograd as ag;

ag::with(|g| {
   let a = g.ones(&[2, 2]);
   let b = g.ones(&[2, 2]);
   let c = g.ones(&[2, 2]);
   let d = g.add_n(&[a, b, c]);

   assert_eq!(d.eval(&[]).unwrap().shape(), &[2, 2]);
   assert_eq!(d.eval(&[]), Ok(array![[3., 3.], [3., 3.]].into_dyn()));
});

Compares two tensors elementwise and returns a binary tensor.

return-value[i] is 1 if a[i] == b[i], and 0 otherwise.

Panics

When the inputs cannot be broadcast together.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![1., 2., 3.]);
   let b = g.constant(array![3., 2., 1.]);
   let c = g.equal(a, b);
   assert_eq!(c.eval(&[]), Ok(ndarray::arr1(&[0., 1., 0.]).into_dyn()));
});

Compares two tensors elementwise and returns a binary tensor.

return-value[i] is 1 if a[i] != b[i], and 0 otherwise.

Panics

When the inputs cannot be broadcast together.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![1., 2., 3.]);
   let b = g.constant(array![3., 2., 1.]);
   let c = g.not_equal(a, b);
   assert_eq!(c.eval(&[]), Ok(array![1., 0., 1.].into_dyn()));
});

Takes argmin along specified axis.

axis can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[3., 4.], [6., 5.]]);
   let y = g.argmin(x, 1, false);

   assert_eq!(y.eval(&[]), Ok(array![0., 1.].into_dyn()));
});

Takes argmax along specified axis.

axis can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[3., 4.], [6., 5.]]);
   let y = g.argmax(x, 1, false);

   assert_eq!(y.eval(&[]), Ok(array![1., 0.].into_dyn()));
});

Expands the shape (inserts axes).

Each axis can be negative.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[3]);
   let b = g.expand_dims(a, &[0, 2]);
   assert_eq!(b.eval(&[]).unwrap().shape(), &[1, 3, 1]);
});

Remove the specified dims.

Each axis can be negative.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[1, 3, 1]);
   let b = g.squeeze(a, &[0, 2]);
   assert_eq!(b.eval(&[]).unwrap().shape(), &[3]);
});

Tiles the input tensor along specified axis.

Tiles input tensor num times along axis. axis can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 2.], [3., 3.]]);
   let y = g.tile(x, 0, 2);

   assert_eq!(
       y.eval(&[]),
       Ok(array![[2., 2.], [3., 3.], [2., 2.], [3., 3.]].into_dyn())
   );
});

Clamps all elements of x to the range [min, max].

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![2., 4., 6.]);
   let y = g.clip(x, 3., 5.);
   assert_eq!(y.eval(&[]), Ok(ndarray::arr1(&[3., 4., 5.]).into_dyn()));
});

Takes max along specified axes.

Each element of axes can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 4.], [3., 1.]]);
   let y = g.reduce_max(&x, &[0], false);
   assert_eq!(y.eval(&[]), Ok(array![3., 4.].into_dyn()));
});

Takes min along specified axes.

Each element of axes can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 4.], [3., 1.]]);
   let y = g.reduce_min(&x, &[0], false);
   assert_eq!(y.eval(&[]), Ok(array![2., 1.].into_dyn()));
});

Sums all the elements into a scalar value (0-D Tensor).

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 4.], [3., 1.]]);
   let y = g.reduce_sum_to_scalar(&x);
   assert_eq!(y.eval(&[]), Ok(ndarray::arr0(10.).into_dyn()));
});

Takes the sum along the specified axes.

Elements of axes can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 4.], [3., 1.]]);
   let y = g.reduce_sum(&x, &[1], false);

   assert_eq!(y.eval(&[]), Ok(array![6., 4.].into_dyn()));
});

Takes mean along specified axes.

Elements of axes can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 4.], [3., 1.]]);
   let y = g.reduce_mean(x, &[1], false);
   assert_eq!(y.eval(&[]), Ok(array![3., 2.].into_dyn()));
});

Takes product along specified axes.

Elements of axes can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[2., 4.], [3., 1.]]);
   let y = g.reduce_prod(&x, &[1], false);
   assert_eq!(y.eval(&[]), Ok(array![8., 3.].into_dyn()));
});

Computes the population variance along the specified axes.

Elements of axes can be negative.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let x = g.constant(array![[1., 1.], [2., 2.]]);
   let y = g.reduce_variance(&x, &[1], false);
   assert_eq!(y.eval(&[]), Ok(array![0., 0.].into_dyn()));
});
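
"Population" variance divides the summed squared deviations by the number of elements n rather than n - 1. A minimal sketch in plain Rust for the axis-1 case of a 2-D input follows; the name `row_population_variance` is illustrative and not part of the crate.

```rust
/// Population variance of each row: mean((x - mean(x))^2).
/// The sum of squared deviations is divided by n, not n - 1.
fn row_population_variance(rows: &[Vec<f64>]) -> Vec<f64> {
    rows.iter()
        .map(|row| {
            let n = row.len() as f64;
            let mean = row.iter().sum::<f64>() / n;
            row.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n
        })
        .collect()
}
```

For the rows [1, 1] and [2, 2] used above, every element equals its row mean, so both variances are 0. A row such as [1, 3] has mean 2 and population variance (1 + 1) / 2 = 1.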

Reshapes the input tensor without copy.

Only one element in shape can be -1.

use ndarray;
use autograd as ag;

ag::with(|g| {
   let x: ag::Tensor<f32> = g.zeros(&[3, 2, 2]);
   let y = g.reshape(&x, &[3, -1]);
   assert_eq!(y.eval(&[]), Ok(ag::ndarray_ext::zeros::<f32>(&[3, 4])));
});
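
The single -1 entry stands for "whatever size makes the element count match". The sketch below shows one way such a dimension could be resolved in plain Rust. `infer_reshape` is a hypothetical helper, and returning `None` for invalid requests is an assumption of this sketch; how the crate itself reports such errors is not described above.

```rust
/// Resolves at most one -1 entry in `shape` so that the product of the
/// resulting dimensions equals `total`. Returns None if the request is
/// ambiguous (more than one -1), contains other negative values, or
/// cannot match `total`.
fn infer_reshape(total: usize, shape: &[isize]) -> Option<Vec<usize>> {
    let unknown = shape.iter().filter(|&&d| d == -1).count();
    if unknown > 1 || shape.iter().any(|&d| d < -1) {
        return None;
    }
    // Product of the explicitly given dimensions.
    let known: usize = shape
        .iter()
        .filter(|&&d| d != -1)
        .map(|&d| d as usize)
        .product();
    if unknown == 0 {
        return if known == total {
            Some(shape.iter().map(|&d| d as usize).collect())
        } else {
            None
        };
    }
    if known == 0 || total % known != 0 {
        return None;
    }
    let inferred = total / known;
    Some(
        shape
            .iter()
            .map(|&d| if d == -1 { inferred } else { d as usize })
            .collect(),
    )
}
```

With 3 * 2 * 2 = 12 elements, `infer_reshape(12, &[3, -1])` yields `Some(vec![3, 4])`, which is the shape asserted in the example above.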

Flattens the input tensor into a rank-1 tensor (vector) without copy.

use autograd as ag;

ag::with(|g| {
   let x: ag::Tensor<f32> = g.zeros(&[3, 2, 2]);
   let z = g.flatten(x);
   assert_eq!(z.eval(&[]).unwrap().shape(), &[12]);
});

Returns -1 if x < 0, 0 if x==0, 1 if x > 0, element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![-5., 4.5, 0.]);
   let b = g.sign(a);
   assert_eq!(
       b.eval(&[]).unwrap().as_slice().unwrap(),
       &[-1., 1., 0.]
   );
});

Returns the absolute value, element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![-0.2, 0., 0.2]);
   let b = g.abs(a);
   assert_eq!(
       b.eval(&[]),
       Ok(ndarray::arr1(&[0.2, 0., 0.2]).into_dyn())
   );
});

Returns the largest integer less than or equal to a number, element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]);
   let b = g.floor(a);
   assert_eq!(
       b.eval(&[]),
       Ok(array![-2., -2., -1.,  0.,  1.,  1.,  2.].into_dyn())
   );
});

Elementwise negation (unary -).

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![2., 3.]);
   let b = g.neg(a);
   assert_eq!(
       b.eval(&[]),
       Ok(array![-2., -3.].into_dyn())
   );
});

Takes square of the input.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![2., 3.]);
   let b = g.square(a);
   assert_eq!(
       b.eval(&[]),
       Ok(array![4., 9.].into_dyn())
   );
});

Returns 1/x, element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![2.]);
   let b = g.inv(a);
   assert_eq!(
       b.eval(&[]),
       Ok(array![0.5].into_dyn())
   );
});

Returns 1/sqrt(x), element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![4.]);
   let b = g.inv_sqrt(a);
   assert_eq!(
       b.eval(&[]),
       Ok(array![0.5].into_dyn())
   );
});

Returns the smallest integer greater than or equal to a number, element-wise.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]);
   let b = g.ceil(a);
   assert_eq!(
       b.eval(&[]),
       Ok(array![-1., -1., -0.,  1.,  2.,  2.,  2.].into_dyn())
   );

});

Compares two tensors elementwise and returns a binary tensor.

Panics

When the inputs cannot be broadcast together.

Compares two tensors elementwise and returns a binary tensor.

Panics

When the inputs cannot be broadcast together.

Compares two tensors elementwise and returns a binary tensor.

Panics

When the inputs cannot be broadcast together.

Compares two tensors elementwise and returns a binary tensor.

Panics

When the inputs cannot be broadcast together.

Elementwise logistic sigmoid function.

Elementwise exponential linear unit.

See https://arxiv.org/abs/1511.07289

Elementwise rectified linear unit.

Elementwise leaky relu.

Typically, alpha is around 0.1 to 0.2.

See http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf
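
As a scalar illustration, the following sketch assumes the usual definition of leaky ReLU: f(x) = x for x > 0 and alpha * x otherwise. The function name `leaky_relu` here is a plain Rust helper for illustration, not the crate's tensor operation.

```rust
/// Leaky ReLU on a single value, assuming the common definition:
/// positive inputs pass through, non-positive inputs are scaled by `alpha`.
fn leaky_relu(x: f64, alpha: f64) -> f64 {
    if x > 0.0 {
        x
    } else {
        alpha * x
    }
}
```

With alpha = 0.1, an input of 2 stays 2, while an input of -2 becomes approximately -0.2, so negative inputs still carry a small, non-zero slope.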

Elementwise softplus.

Computes log(sum(exp(x))) along specified axis.

axis can be negative.
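
Evaluating log(sum(exp(x))) literally overflows for large inputs. A common remedy is to subtract the maximum before exponentiating. The sketch below shows that trick on a plain slice; whether the crate uses exactly this formulation is an assumption, and `log_sum_exp` is an illustrative name.

```rust
/// Numerically stable log(sum(exp(x))) over a slice, using the identity
/// log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x).
fn log_sum_exp(xs: &[f64]) -> f64 {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max.is_infinite() {
        // Empty input, all entries -inf, or a +inf entry:
        // shifting by `max` would produce NaN, so return it directly.
        return max;
    }
    let sum: f64 = xs.iter().map(|x| (x - max).exp()).sum();
    max + sum.ln()
}
```

For the input [1000, 1000], the literal formula gives infinity because exp(1000) overflows an f64, while the shifted version returns 1000 + ln 2.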

Log softmax function.

Computes softmax(x) along specified axis and takes logarithm of it. axis can be negative.

Computes softmax along specified axis

axis can be negative.
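
For a single 1-D slice, softmax and log softmax can be sketched in plain Rust as below. Both subtract the maximum first so that exp never overflows. The names `softmax` and `log_softmax` are standalone helpers written for illustration, and the max-shift is an assumption about a sensible implementation, not a statement about the crate's internals.

```rust
/// Softmax of a slice: exp(x_i - m) / sum_j exp(x_j - m), with m = max(x).
fn softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Log softmax of a slice: x_i - m - log(sum_j exp(x_j - m)).
/// Computed directly instead of taking the log of `softmax`.
fn log_softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = xs.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
    xs.iter().map(|x| x - max - log_sum).collect()
}
```

The outputs of `softmax` sum to 1, and for two equal inputs each probability is 0.5, so each log-probability is -ln 2.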

Computes binary_cross_entropy(sigmoid(y), t).

Prefer this function over composing the two operations by hand, since it prevents underflow in log(sigmoid).

Arguments
  • y - Tensor with arbitrary shape
  • t - Ground-truth Tensor with the same shape as y
Panics

When y.shape != t.shape.

Returns

Loss tensor with the same shape as the inputs
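
The numerical issue can be seen on a single element. The sketch below compares the literal composition with a well-known stable rewrite, max(y, 0) - y*t + log(1 + exp(-|y|)), which is algebraically equal to binary_cross_entropy(sigmoid(y), t). Both function names are hypothetical, and the stable form is a standard identity, not necessarily the exact expression the crate evaluates.

```rust
/// Literal composition: -(t * ln(s) + (1 - t) * ln(1 - s)) with s = sigmoid(y).
/// For strongly negative y, s underflows to 0 and ln(s) becomes -inf.
fn sigmoid_cross_entropy_naive(y: f64, t: f64) -> f64 {
    let s = 1.0 / (1.0 + (-y).exp());
    -(t * s.ln() + (1.0 - t) * (1.0 - s).ln())
}

/// Stable rewrite: max(y, 0) - y * t + ln(1 + exp(-|y|)).
/// exp is only ever applied to a non-positive argument, so it cannot overflow.
fn sigmoid_cross_entropy_stable(y: f64, t: f64) -> f64 {
    y.max(0.0) - y * t + (-y.abs()).exp().ln_1p()
}
```

For moderate inputs the two agree to rounding error. For y = -800 and t = 1, the naive version returns infinity, while the stable version returns 800.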

Computes categorical_cross_entropy(softmax(y), t).

Prefer this function over composing the two operations by hand, since it prevents underflow in log(softmax).

Arguments
  • y - Tensor with shape (batch_size, num_classes)
  • t - Tensor with shape (batch_size, num_classes)
Returns

Loss tensor with shape (batch_size, 1)

A variant of softmax_cross_entropy.

This function behaves like softmax_cross_entropy, except that t is not a batch of one-hot distributions but a batch of ground-truth label ids.

Arguments
  • y - Tensor with shape (batch_size, num_classes)
  • t - Tensor with shape (batch_size,) or (batch_size, 1)
Returns

Loss tensor with shape (batch_size, 1)
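
The relation between the one-hot and the label-id variants can be illustrated for a single row of logits. This is a minimal sketch in plain Rust; the three function names are hypothetical stand-ins that operate on slices, and they do not mirror the crate's signatures.

```rust
/// Log softmax of one row of logits, with the maximum subtracted for stability.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = logits.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
    logits.iter().map(|x| x - max - log_sum).collect()
}

/// Cross entropy against a full target distribution (e.g. a one-hot row):
/// -sum_i t_i * log_softmax(y)_i.
fn softmax_cross_entropy(logits: &[f64], target: &[f64]) -> f64 {
    log_softmax(logits)
        .iter()
        .zip(target)
        .map(|(lp, t)| -(t * lp))
        .sum()
}

/// Cross entropy against a label id: -log_softmax(y)[label].
fn sparse_softmax_cross_entropy(logits: &[f64], label: usize) -> f64 {
    -log_softmax(logits)[label]
}
```

A one-hot target with a 1 at position 2 and the label id 2 produce the same loss, because every other term of the sum is multiplied by 0.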

Matrix multiplication.

Both a and b must be rank-2 tensors.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[4, 2]);
   let b: ag::Tensor<f32> = g.zeros(&[2, 3]);
   let c = g.matmul(a, b);
   assert_eq!(c.eval(&[]).unwrap().shape(), &[4, 3]);
});

This function supports only f32 and f64.

Computes tensor-dot-product (tensor contraction) along specified axes.

Arguments
  • a - First input tensor
  • b - Second input tensor
  • a_axes - a’s Contraction axes
  • b_axes - b’s Contraction axes

NOTE:

  • length of a_axes and b_axes must match.
  • Each axis number can be negative.
  • Supports only f32 and f64.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[3, 4, 5]);
   let b: ag::Tensor<f32> = g.zeros(&[4, 3, 2]);
   let c = g.tensordot(a, b, &[1, 0], &[0, 1]);
   assert_eq!(c.eval(&[]).unwrap().shape(), &[5, 2]);
});

For detailed description, see https://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html.
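
The output shape of a tensor contraction follows NumPy's convention: the non-contracted axes of a, in order, followed by the non-contracted axes of b. The sketch below computes that shape in plain Rust. `tensordot_shape` is a hypothetical helper, and it accepts only non-negative axis numbers for simplicity.

```rust
/// Output shape of a tensordot-style contraction: the axes of `a` that are
/// not contracted, followed by the axes of `b` that are not contracted.
/// Panics if the axis lists differ in length or a contracted pair of
/// dimensions does not match.
fn tensordot_shape(
    a: &[usize],
    b: &[usize],
    a_axes: &[usize],
    b_axes: &[usize],
) -> Vec<usize> {
    assert_eq!(a_axes.len(), b_axes.len(), "axis lists must have equal length");
    for (&i, &j) in a_axes.iter().zip(b_axes) {
        assert_eq!(a[i], b[j], "contracted dimensions must match");
    }
    let mut out: Vec<usize> = a
        .iter()
        .enumerate()
        .filter(|(i, _)| !a_axes.contains(i))
        .map(|(_, &d)| d)
        .collect();
    out.extend(
        b.iter()
            .enumerate()
            .filter(|(j, _)| !b_axes.contains(j))
            .map(|(_, &d)| d),
    );
    out
}
```

For the shapes in the example above, (3, 4, 5) and (4, 3, 2) contracted over a_axes = [1, 0] and b_axes = [0, 1], only the 5 from a and the 2 from b remain, giving (5, 2).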

Batched matrix multiplication with optional transposition of the inputs.

a and b must have the same rank.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[2, 3, 2, 4]);
   let b: ag::Tensor<f32> = g.zeros(&[2, 3, 2, 3]);
   let c = g.batch_matmul_t(a, b, true, false);
   assert_eq!(c.eval(&[]).unwrap().shape(), &[2, 3, 4, 3]);
});

This function supports only f32 and f64. For detailed description, see https://www.tensorflow.org/api_docs/python/tf/matmul

Batched matrix multiplication.

a and b must have the same rank.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.ones(&[2, 3, 4, 2]);
   let b: ag::Tensor<f32> = g.ones(&[2, 3, 2, 3]);
   let c = g.batch_matmul(a, b);
   assert_eq!(c.eval(&[]).unwrap().shape(), &[2, 3, 4, 3]);
});

This function supports only f32 and f64. For detailed description, see https://www.tensorflow.org/api_docs/python/tf/matmul

Takes diff between two tensors.

Returns the sorted, unique values in a that are not in b.

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let a = g.constant(array![4., 1., 5., 2., 3., 6.]);
   let b = g.constant(array![[2., 3.], [1., 4.]]);
   let c = g.setdiff1d(a, b);
   assert_eq!(
       c.eval(&[]),
       Ok(ndarray::arr1(&[5., 6.]).into_dyn())
   )
});

Permutes dimensions without copy.

Behaves like transpose in TensorFlow or NumPy. The rank (ndim) of x must equal axes.len().

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[1, 2, 3, 4, 5]);
   let b = g.transpose(a, &[4, 2, 3, 0, 1]);
   assert_eq!(b.eval(&[]).unwrap().shape(), &[5, 3, 4, 1, 2]);
});

Splits the input tensor into parts.

Splits x into sizes.len() parts along axis.

Along axis, the i-th part has size sizes[i]; along every other axis j, it keeps x.shape[j] (similar to TensorFlow’s split).

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[3, 7, 5]);
   let b = g.split(a, &[2, 3, 2], 1);

   let evaluated = g.eval(&[&b[0], &b[1], &b[2]], &[]);
   let e0 = &evaluated[0];
   let e1 = &evaluated[1];
   let e2 = &evaluated[2];

   assert_eq!(e0.as_ref().unwrap().shape(), &[3, 2, 5]);
   assert_eq!(e1.as_ref().unwrap().shape(), &[3, 3, 5]);
   assert_eq!(e2.as_ref().unwrap().shape(), &[3, 2, 5]);
});
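
The shapes of the parts can be derived without evaluating anything. A minimal sketch in plain Rust, where `split_shapes` is a hypothetical helper and `axis` is assumed to be non-negative:

```rust
/// Shapes of the parts produced by splitting a tensor of shape `shape`
/// into `sizes.len()` parts along `axis`: each part copies `shape`
/// and replaces the extent along `axis` with the corresponding size.
fn split_shapes(shape: &[usize], sizes: &[usize], axis: usize) -> Vec<Vec<usize>> {
    sizes
        .iter()
        .map(|&size| {
            let mut part = shape.to_vec();
            part[axis] = size;
            part
        })
        .collect()
}
```

For shape (3, 7, 5), sizes [2, 3, 2] and axis 1 this gives (3, 2, 5), (3, 3, 5) and (3, 2, 5), the shapes asserted in the example above. The sizes sum to 7, the extent of the split axis in that example.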

Slices the input tensor.

Arguments
  • x - Tensor with arbitrary shape.
  • starts - Inclusive start indices for the dimensions.
  • ends - End indices for the dimensions. A negative index is inclusive; a non-negative index is exclusive.

NOTE: Negative values in starts and ends are counted from the back of the axis.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[4, 4]);
   let b = g.slice(a, &[0, 0], &[-1, 2]); // numpy equivalent is a[:, 0:2]

   assert_eq!(b.eval(&[]).unwrap().shape(), &[4, 2]);
});

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[4, 4]);
   let b = g.slice(a, &[0, 0], &[-2, 2]); // numpy equivalent is a[:-1, :2]

   assert_eq!(b.eval(&[]).unwrap().shape(), &[3, 2]);
});
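
The inclusive/exclusive rule for the end indices is easiest to see by converting each (start, end) pair into an ordinary half-open range. The sketch below does that for one axis in plain Rust. `resolve_slice` is a hypothetical helper that follows the rule as documented above and performs no bounds checking.

```rust
use std::ops::Range;

/// Converts a (start, end) pair for an axis of length `len` into a
/// half-open range. Negative values count from the back of the axis.
/// A negative `end` is inclusive, so -1 means "up to and including the
/// last element"; a non-negative `end` is exclusive.
fn resolve_slice(len: usize, start: isize, end: isize) -> Range<usize> {
    let len = len as isize;
    let s = if start < 0 { len + start } else { start };
    let e = if end < 0 { len + end + 1 } else { end };
    (s as usize)..(e as usize)
}
```

On an axis of length 4, the pair (0, -1) resolves to 0..4, the pair (0, -2) to 0..3, and the pair (0, 2) to 0..2, which reproduces the shapes (4, 2) and (3, 2) asserted in the two examples above.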

Concatenates input tensors along specified axis.

axis can be negative.

use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[3, 2]);
   let b: ag::Tensor<f32> = g.zeros(&[3, 2]);
   let c: ag::Tensor<f32> = g.zeros(&[3, 2]);
   let d = g.concat(&[a, b, c], 0);

   assert_eq!(d.eval(&[]).unwrap().shape(), &[9, 2]);
});

Gathers subviews from the input tensor.

Same spec as https://www.tensorflow.org/api_docs/python/tf/gather. For example, this can be used for embedding-vector lookup.

Unlike ag::gather, indices can contain negative elements.

Returns

Tensor with shape param.shape[..axis] + indices.shape + param.shape[axis+1..]

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let param = g.zeros(&[5, 4, 8, 2]);
   let indices = g.constant(array![[5., -1., 3.], [2., 1., -2.]]);
   let y = g.gather_common(param, indices, 2);

   assert_eq!(y.eval(&[]).unwrap().shape(), &[5, 4, 2, 3, 2])
});
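
The returned shape can be computed directly from the rule param.shape[..axis] + indices.shape + param.shape[axis+1..]. A minimal sketch in plain Rust, where `gather_output_shape` is a hypothetical helper and `axis` is assumed to be non-negative and in range:

```rust
/// Output shape of a gather along `axis`: the leading dimensions of `param`,
/// then the full shape of `indices`, then the trailing dimensions of `param`.
/// The gathered axis itself is replaced by the shape of `indices`.
fn gather_output_shape(param: &[usize], indices: &[usize], axis: usize) -> Vec<usize> {
    let mut out = param[..axis].to_vec();
    out.extend_from_slice(indices);
    out.extend_from_slice(&param[axis + 1..]);
    out
}
```

For param of shape (5, 4, 8, 2), indices of shape (2, 3) and axis 2, the axis of size 8 is replaced by (2, 3), which gives (5, 4, 2, 3, 2) as asserted above.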

Gathers subviews from the input tensor.

Same spec as https://www.tensorflow.org/api_docs/python/tf/gather. For example, this can be used for embedding-vector lookup.

Returns

Tensor with shape param.shape[..axis] + indices.shape + param.shape[axis+1..]

use ndarray::array;
use autograd as ag;
use ag::tensor::Constant;

ag::with(|g| {
   let param = g.zeros(&[5, 4, 8, 2]);
   let indices = g.constant(array![[5., 4., 3.], [2., 1., 0.]]);  // shape: (2, 3)
   let y = g.gather(param, indices, 2);

   assert_eq!(y.eval(&[]).unwrap().shape(), &[5, 4, 2, 3, 2])
});

Normalizes the input tensor with its mean and variance along specified axis.

use autograd as ag;

ag::with(|g| {
   let x: ag::Tensor<f32> = g.standard_normal(&[3, 4]);
   let y1 = g.normalize(x, &[0]);
   let y2 = g.normalize(x, &[0]);

   let evaluated = g.eval(&[y1, y2], &[]);
   let e0 = &evaluated[0];
   let e1 = &evaluated[1];
   assert_eq!(e0.as_ref().unwrap().shape(), &[3, 4]);
   assert_eq!(e1.as_ref().unwrap().shape(), &[3, 4]);
});

Applies batch normalization.

scale and shift should be shared variables. Since normalization is performed along the first axis of x, both should have shape (1, x.shape[1]).

use autograd as ag;
use ag::tensor::Variable;

ag::with(|g| {
   let x = g.standard_normal(&[3, 4]);
   let scale = g.variable(ag::ndarray_ext::ones::<f32>(&[1, 4]));
   let shift = g.variable(ag::ndarray_ext::zeros::<f32>(&[1, 4]));
   let norm = g.batch_norm(x, scale, shift);

   assert_eq!(norm.eval(&[]).unwrap().shape(), &[3, 4]);
});

Generates a rank-0 tensor from a scalar value.

use autograd as ag;

ag::with(|g| {
   let a = g.scalar(3.);
   println!("{}", a.eval(&[]).unwrap());  // => 3.
   assert_eq!(a.eval(&[]).unwrap().shape(), &[]);
});

Outputs values sampled from the normal distribution.

Outputs values sampled from the normal distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the uniform distribution.

Outputs values sampled from the uniform distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the standard normal distribution.

Outputs values sampled from the standard normal distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the standard uniform distribution.

Outputs values sampled from the standard uniform distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the bernoulli distribution.

Outputs values sampled from the bernoulli distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the exponential distribution.

Outputs values sampled from the exponential distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the gamma distribution.

Outputs values sampled from the gamma distribution.

Pre-instantiated ArrayRng is acceptable.

Outputs values sampled from the log-normal distribution.

Outputs values sampled from the log-normal distribution.

Pre-instantiated ArrayRng is acceptable.

Converts an ndarray::Array to an ag::Tensor.

use ndarray::array;
use autograd as ag;

ag::with(|g| {
   let arr = array![2., 3.];
   let tensor = g.convert_to_tensor(arr.clone());
   assert_eq!(tensor.eval(&[]), Ok(arr.into_dyn()));
});

Returns zeros with given shape.

use ndarray;
use autograd as ag;

ag::with(|g| {
   let a: ag::Tensor<f32> = g.zeros(&[4, 2]);
   assert_eq!(a.eval(&[]), Ok(ndarray::Array2::<f32>::zeros((4, 2)).into_dyn()));
});

Returns ones with given shape.

use ndarray;
use autograd as ag;

ag::with(|g| {
   let a = g.ones(&[4, 2]);
   assert_eq!(a.eval(&[]), Ok(ndarray::Array2::<f32>::ones((4, 2)).into_dyn()));
});

2D convolution.

  • x: Tensor with shape (batch, channel, h, w)
  • w: Tensor with shape (out_channel, channel, filter_h, filter_w)

Returns a tensor with shape (batch, out_channel, out_h, out_w)

where

  • out_h = (h + 2 * pad - filter_h) / stride + 1
  • out_w = (w + 2 * pad - filter_w) / stride + 1

This function supports only f32 and f64.
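
The output-size formula above can be checked with a few lines of plain Rust. `conv_out_size` is a hypothetical helper, and the use of truncating integer division is an assumption of this sketch; how the crate handles inputs that the stride does not divide evenly is not stated above.

```rust
/// Output extent along one spatial axis of a 2D convolution:
/// (input + 2 * pad - filter) / stride + 1, using integer division.
/// Panics on underflow if the filter is larger than the padded input.
fn conv_out_size(input: usize, filter: usize, pad: usize, stride: usize) -> usize {
    (input + 2 * pad - filter) / stride + 1
}
```

A 28-wide input with a 5-wide filter, pad 2 and stride 1 keeps its width of 28. A 32-wide input with a 3-wide filter, no padding and stride 2 gives 15.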

2D convolution with dilation.

  • x: Tensor with shape (batch, channel, h, w)
  • w: Tensor with shape (out_channel, in_channel, filter_h, filter_w)

Returns a tensor with shape (batch, out_channel, out_h, out_w)

where

  • out_h = (h + 2 * pad - (dilate * (filter - 1) + 1)) / stride + 1
  • out_w = (w + 2 * pad - (dilate * (filter - 1) + 1)) / stride + 1

This function supports only f32 and f64.
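
Dilation enlarges the span a filter covers without adding weights: a filter with `filter` taps and dilation `dilate` covers dilate * (filter - 1) + 1 input positions. The sketch below applies the formula above in plain Rust. `dilated_conv_out_size` is a hypothetical helper, and truncating integer division is an assumption.

```rust
/// Output extent along one spatial axis of a dilated 2D convolution.
/// The effective filter extent is dilate * (filter - 1) + 1; with
/// dilate == 1 this reduces to the ordinary convolution formula.
fn dilated_conv_out_size(
    input: usize,
    filter: usize,
    pad: usize,
    stride: usize,
    dilate: usize,
) -> usize {
    let effective = dilate * (filter - 1) + 1;
    (input + 2 * pad - effective) / stride + 1
}
```

A 3-tap filter with dilation 2 spans 5 positions, so a 7-wide input with no padding and stride 1 yields an output width of 3.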

2D transposed convolution.

  • x: Tensor with shape (batch, in_channel, h, w)
  • w: Tensor with shape (in_channel, out_channel, filter_h, filter_w)

Returns a tensor with shape (batch, out_channel, out_h, out_w)

where

  • out_h = stride * (h - 1) - pad + filter_h
  • out_w = stride * (w - 1) - pad + filter_w

This function supports only f32 and f64.

2D transposed convolution with dilation.

  • x: Tensor with shape (batch, in_channel, h, w)
  • w: Tensor with shape (in_channel, out_channel, filter_h, filter_w)

Returns a tensor with shape (batch, out_channel, out_h, out_w)

where

  • out_h = stride * (h - 1) - pad + (dilate * (filter_h - 1) + 1)
  • out_w = stride * (w - 1) - pad + (dilate * (filter_w - 1) + 1)

This function supports only f32 and f64.

2D max pooling.

  • x: Tensor with shape (batch, channel, h, w)

Returns a tensor with shape (batch, channel, out_h, out_w)

where

  • out_h = (h + 2 * pad - pool_size) / stride + 1
  • out_w = (w + 2 * pad - pool_size) / stride + 1

This function supports only f32 and f64.

Evaluates given symbolic tensors as a list of ndarray::Array<F, ndarray::IxDyn>.

Unlike Tensor::eval, this function supports batched evaluation.

See also Eval.

use ndarray::array;
use autograd as ag;

ag::with(|g| {
    let a = g.zeros(&[2]);
    let b = g.ones(&[2]);

    // eval two tensors at once.
    let evaluated = g.eval(&[a, b], &[]);
    assert_eq!(evaluated[0], Ok(array![0., 0.].into_dyn()));
    assert_eq!(evaluated[1], Ok(array![1., 1.].into_dyn()));
});

Trait Implementations

Creates a (persistent) constant tensor from an NdArray, or Arc<NdArray> to prevent move.

Formats the value using the given formatter.

Creates a shared variable tensor from an NdArray, or Arc<RwLock<NdArray>> to prevent move.
